In [None]:
%load_ext autoreload
%autoreload 2

import os
from dotenv import load_dotenv
from pathlib import Path
from research_tools.gpu import get_gpus_available

# Load variables from a local .env file into the process environment so that
# secrets (such as the Hugging Face token below) stay out of the notebook.
load_dotenv()

# May be None when the variable is not set.
hf_access_token = os.environ.get("HUGGINGFACE_API_KEY")

# Upper bound on the number of GPUs this notebook is allowed to claim.
n_gpus = 1

# NOTE(review): presumably returns the ids of currently usable GPUs; confirm
# against research_tools.gpu.get_gpus_available.
gpus_available = get_gpus_available()

# Never ask for more devices than were reported, then keep the first n_gpus.
n_gpus = min(n_gpus, len(gpus_available))
gpus = gpus_available[:n_gpus]

assert n_gpus > 0, "No GPUs available"

# Restrict CUDA to the selected devices. This is done before torch is
# imported in the next cell so the setting is in place when CUDA initialises.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpus))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Optional
import torch
from research_tools.utils import set_seed

# NOTE(review): presumably seeds the relevant RNGs for reproducibility;
# confirm against research_tools.utils.set_seed.
set_seed(42)

# This notebook needs a GPU: stop immediately if CUDA is unavailable.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
assert device == torch.device("cuda")

model_id = "HuggingFaceH4/zephyr-7b-beta"

# Load the weights in bfloat16, then move the whole model onto the GPU.
# NOTE(review): trust_remote_code=True allows code shipped in the Hub repo to
# run locally; confirm it is actually required for this model.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_access_token,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mnt/align4_drive/data/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/.no_exist/892b3d7a7b1cf10c7a701c60881cd93df615734c/chat_template.jinja'


In [6]:
from datasets import load_dataset
from relearn.datasets.utils import (
    load_dataset as local_load_dataset,
    DATASETS_DICT,
    Datasets,
)
from relearn.datasets.corpus import process as process_corpus
from relearn.datasets.mcq import process as process_mcq

# Configuration entry for the WMDP dataset; the file lists read from it below
# ("unlearn_files", "val_unlearn_files", "retain_files", "val_retain_files")
# drive how the data is split.
dataset_config = DATASETS_DICT[Datasets.WMDP]

# Disabled alternative retain corpus (wikitext-2), kept for reference:
# retain_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Root directory handed to local_load_dataset; relative to the notebook's
# working directory.
data_dir = Path("../data")


def get_dataset(train_files: List[str], val_files: List[str], max_length: int):
    """Load one data split and return its processed record sets.

    Relies on the module-level ``data_dir`` and ``tokenizer``.

    Args:
        train_files: File names loaded (via ``local_load_dataset``) and fed to
            ``process_corpus``.
        val_files: File names loaded (via ``local_load_dataset``) and fed to
            ``process_mcq``.
        max_length: Forwarded unchanged to ``process_corpus`` and
            ``process_mcq``.

    Returns:
        A dict with three entries:
        ``"corpus"`` - output of ``process_corpus`` on the training files;
        ``"mcq"`` - output of ``process_mcq`` with ``expand_choices=False``;
        ``"val"`` - output of ``process_mcq`` with its default settings.
    """
    raw_train = local_load_dataset(data_dir, train_files)
    raw_val = local_load_dataset(data_dir, val_files)

    corpus_split = process_corpus(raw_train, tokenizer, max_length)
    # The validation questions are processed twice: once with the helper's
    # defaults and once with choice expansion explicitly switched off.
    val_split = process_mcq(raw_val, tokenizer, max_length)
    mcq_split = process_mcq(raw_val, tokenizer, max_length, expand_choices=False)

    return {
        "corpus": corpus_split,
        "mcq": mcq_split,
        "val": val_split,
    }


# Split boundary: the first n_val_files unlearn files (train and val lists
# alike) form split "A", everything after them forms split "B".
n_val_files = 4
# Maximum length forwarded to the corpus / MCQ processing helpers.
max_length = 512

unlearn_files = dataset_config["unlearn_files"]
val_unlearn_files = dataset_config["val_unlearn_files"]
retain_files = dataset_config["retain_files"]
val_retain_files = dataset_config["val_retain_files"]

# One entry per split; each value is the {"corpus", "mcq", "val"} dict built
# by get_dataset. Splits are processed in the order A, B, retain.
store = {
    split_name: get_dataset(split_train_files, split_val_files, max_length)
    for split_name, split_train_files, split_val_files in (
        ("A", unlearn_files[:n_val_files], val_unlearn_files[:n_val_files]),
        ("B", unlearn_files[n_val_files:], val_unlearn_files[n_val_files:]),
        ("retain", retain_files, val_retain_files),
    )
}

Filter:   0%|          | 0/1884 [00:00<?, ? examples/s]

Map:   0%|          | 0/1884 [00:00<?, ? examples/s]

Filter:   0%|          | 0/471 [00:00<?, ? examples/s]

Map:   0%|          | 0/471 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [7]:
# Evaluation sets keyed by split name: the "val" records of every split.
eval_dict = {split_name: splits["val"] for split_name, splits in store.items()}

In [5]:
import wandb

# Authenticate with Weights & Biases before creating the run.
wandb.login()

# Hyperparameters recorded with the run. The alpha dicts are keyed by the
# split names ("A", "B", "retain") used for the datasets passed to train_rmu.
config = {
    "model_id": model_id,
    "magnitude": 6.5,
    "lr": 1e-5,
    "n_epochs": 12,
    "forget_alphas": {"A": 0.39422},
    "retain_alphas": {"B": 13.51609, "retain": 1},
    "datasets_config": dataset_config,
}

# Start the tracked run; the training cell reads its settings back from
# run.config.
run = wandb.init(
    entity="12tqian",
    project="relearn",
    tags=["rmu", "debug"],
    config=config,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33m12tqian[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
from relearn.unlearn.rmu import train_rmu

# Read the hyperparameters back from the W&B run (this rebinds `config`,
# which previously held the plain dict passed to wandb.init).
config = run.config

# First RMU stage. Judging by the matching keys in forget_alphas /
# retain_alphas, the first dataset dict is the forget set (split A) and the
# second is the retain set (split B plus the generic retain corpus).
model = train_rmu(
    model,
    {name: store[name]["corpus"] for name in ("A",)},
    {name: store[name]["corpus"] for name in ("B", "retain")},
    tokenizer=tokenizer,
    eval_records_dict=eval_dict,
    forget_alphas=config["forget_alphas"],
    retain_alphas=config["retain_alphas"],
    magnitude=config["magnitude"],
    lr=config["lr"],
    n_epochs=config["n_epochs"],
    max_batches=None,
    eval_at_start=False,
    verbose=True,
    debug=False,
)

In [7]:
from pathlib import Path

# Directory of a previously saved checkpoint. NOTE(review): judging by its
# name this is the model after the first RMU stage (unlearn A, retain B);
# confirm that it matches the training cell above.
path = Path("../models/wmdp") / "unlearn_A_retain_B"

# To (re)create the checkpoint from the in-memory model, run once:
#     os.makedirs(path, exist_ok=True)
#     model.save_pretrained(path)

# from_pretrained only treats its argument as a local checkpoint when the
# directory exists; otherwise it tries to resolve it as a Hub model id and
# fails with a misleading error. Fail early with a clear message instead.
if not path.is_dir():
    raise FileNotFoundError(
        f"Checkpoint directory not found: {path.resolve()}. "
        "Save the trained model with model.save_pretrained(path) first."
    )

# Replace the in-memory model with the checkpoint, in bfloat16 on the GPU.
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(
    device
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
from relearn.unlearn.rmu import train_rmu


# Second RMU stage on the reloaded checkpoint. Judging by the matching alpha
# keys, split B is now the forget set and only the generic retain corpus is
# retained. Hyperparameters are given inline rather than read from `config`.
model = train_rmu(
    model,
    {name: store[name]["corpus"] for name in ("B",)},
    {name: store[name]["corpus"] for name in ("retain",)},
    tokenizer=tokenizer,
    eval_records_dict=eval_dict,
    forget_alphas={"B": 1},
    retain_alphas={"retain": 1},
    magnitude=6.5,
    lr=1e-5,
    n_epochs=5,
    max_batches=10,
    eval_at_start=False,
    verbose=True,
)

  forget_loss = torch.nn.functional.mse_loss(
100%|██████████| 10/10 [00:04<00:00,  2.31it/s, B/forget_loss=0.0352, retain/retain_loss=0.000328]
  0%|          | 0/314 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
100%|██████████| 314/314 [00:20<00:00, 15.32it/s]
100%|██████████| 79/79 [00:05<00:00, 14.69it/s]
100%|██████████| 393/393 [00:42<00:00,  9.25it/s]


In [9]:
from relearn.attacks.rtt import train_rtt

# Relearn only split A: the attack trains on A's MCQ records, while eval_dict
# still supplies the evaluation sets of every split.
# NOTE(review): the positional 10 looks like the epoch count (the logged
# output runs epochs 0-9); confirm against train_rtt's signature.
model = train_rtt(
    model,
    10,
    store["A"]["mcq"],
    eval_dict,
    lr=1e-6,
    batch_size=2,
    grad_accum_steps=2,
    eval_at_start=True,
)

  0%|          | 0/314 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
100%|██████████| 314/314 [00:21<00:00, 14.90it/s]


Start A Accuracy: 0.1735668789808917


100%|██████████| 79/79 [00:05<00:00, 14.68it/s]


Start B Accuracy: 0.28662420382165604


51it [00:09,  5.64it/s]

Epoch 0, Step 49, Loss 0.2362062931060791


101it [00:19,  5.45it/s]

Epoch 0, Step 99, Loss 0.23506104946136475


151it [00:28,  5.52it/s]

Epoch 0, Step 149, Loss 0.22673895955085754


201it [00:37,  5.61it/s]

Epoch 0, Step 199, Loss 0.08732502907514572


251it [00:46,  5.51it/s]

Epoch 0, Step 249, Loss 0.11531460285186768


301it [00:56,  5.59it/s]

Epoch 0, Step 299, Loss 0.165248304605484


314it [00:58,  5.36it/s]
100%|██████████| 314/314 [00:20<00:00, 15.15it/s]


Epoch 0 A Accuracy: 0.40764331210191085


100%|██████████| 79/79 [00:05<00:00, 14.03it/s]


Epoch 0 B Accuracy: 0.3375796178343949


51it [00:09,  5.75it/s]

Epoch 1, Step 49, Loss 0.09760218113660812


101it [00:18,  5.62it/s]

Epoch 1, Step 99, Loss 0.1078425943851471


151it [00:27,  5.61it/s]

Epoch 1, Step 149, Loss 0.11090829223394394


201it [00:36,  5.46it/s]

Epoch 1, Step 199, Loss 0.05048422887921333


251it [00:46,  5.49it/s]

Epoch 1, Step 249, Loss 0.13568705320358276


301it [00:55,  5.72it/s]

Epoch 1, Step 299, Loss 0.1353950947523117


314it [00:57,  5.42it/s]
100%|██████████| 314/314 [00:20<00:00, 15.08it/s]


Epoch 1 A Accuracy: 0.46496815286624205


100%|██████████| 79/79 [00:05<00:00, 14.55it/s]


Epoch 1 B Accuracy: 0.34394904458598724


51it [00:09,  5.19it/s]

Epoch 2, Step 49, Loss 0.05741976574063301


101it [00:18,  5.61it/s]

Epoch 2, Step 99, Loss 0.10379991680383682


151it [00:27,  5.52it/s]

Epoch 2, Step 149, Loss 0.09365157037973404


201it [00:37,  5.42it/s]

Epoch 2, Step 199, Loss 0.16939611732959747


251it [00:46,  5.68it/s]

Epoch 2, Step 249, Loss 0.11671654880046844


301it [00:55,  5.56it/s]

Epoch 2, Step 299, Loss 0.1424107402563095


314it [00:58,  5.39it/s]
100%|██████████| 314/314 [00:20<00:00, 15.12it/s]


Epoch 2 A Accuracy: 0.535031847133758


100%|██████████| 79/79 [00:05<00:00, 14.56it/s]


Epoch 2 B Accuracy: 0.37579617834394907


51it [00:09,  5.30it/s]

Epoch 3, Step 49, Loss 0.1300419420003891


101it [00:18,  5.65it/s]

Epoch 3, Step 99, Loss 0.09453134983778


151it [00:28,  5.49it/s]

Epoch 3, Step 149, Loss 0.032910268753767014


201it [00:37,  5.55it/s]

Epoch 3, Step 199, Loss 0.04037759080529213


251it [00:46,  5.50it/s]

Epoch 3, Step 249, Loss 0.054063934832811356


301it [00:55,  5.63it/s]

Epoch 3, Step 299, Loss 0.14921537041664124


314it [00:58,  5.38it/s]
100%|██████████| 314/314 [00:20<00:00, 15.11it/s]


Epoch 3 A Accuracy: 0.64171974522293


100%|██████████| 79/79 [00:05<00:00, 14.55it/s]


Epoch 3 B Accuracy: 0.3630573248407643


51it [00:09,  5.45it/s]

Epoch 4, Step 49, Loss 0.03481728583574295


101it [00:18,  5.65it/s]

Epoch 4, Step 99, Loss 0.12688961625099182


151it [00:27,  5.54it/s]

Epoch 4, Step 149, Loss 0.07595080882310867


201it [00:37,  5.55it/s]

Epoch 4, Step 199, Loss 0.038658883422613144


251it [00:46,  5.52it/s]

Epoch 4, Step 249, Loss 0.040477629750967026


301it [00:55,  5.42it/s]

Epoch 4, Step 299, Loss 0.03954736143350601


314it [00:58,  5.38it/s]
100%|██████████| 314/314 [00:20<00:00, 15.10it/s]


Epoch 4 A Accuracy: 0.6671974522292994


100%|██████████| 79/79 [00:05<00:00, 14.55it/s]


Epoch 4 B Accuracy: 0.40764331210191085


50it [00:09,  5.37it/s]

Epoch 5, Step 49, Loss 0.06460663676261902


101it [00:18,  5.41it/s]

Epoch 5, Step 99, Loss 0.03934173285961151


151it [00:28,  5.09it/s]

Epoch 5, Step 149, Loss 0.027502382174134254


201it [00:37,  5.56it/s]

Epoch 5, Step 199, Loss 0.024249615147709846


251it [00:46,  5.30it/s]

Epoch 5, Step 249, Loss 0.056232623755931854


301it [00:56,  5.63it/s]

Epoch 5, Step 299, Loss 0.09097670763731003


314it [00:58,  5.37it/s]
100%|██████████| 314/314 [00:20<00:00, 15.09it/s]


Epoch 5 A Accuracy: 0.7722929936305732


100%|██████████| 79/79 [00:05<00:00, 14.55it/s]


Epoch 5 B Accuracy: 0.37579617834394907


51it [00:09,  5.49it/s]

Epoch 6, Step 49, Loss 0.010385005734860897


101it [00:18,  5.40it/s]

Epoch 6, Step 99, Loss 0.033517491072416306


151it [00:28,  5.40it/s]

Epoch 6, Step 149, Loss 0.07183639705181122


201it [00:37,  5.37it/s]

Epoch 6, Step 199, Loss 0.07671858370304108


251it [00:46,  5.42it/s]

Epoch 6, Step 249, Loss 0.039988599717617035


301it [00:56,  5.45it/s]

Epoch 6, Step 299, Loss 0.013536284677684307


314it [00:58,  5.37it/s]
100%|██████████| 314/314 [00:20<00:00, 15.10it/s]


Epoch 6 A Accuracy: 0.8789808917197452


100%|██████████| 79/79 [00:05<00:00, 14.58it/s]


Epoch 6 B Accuracy: 0.40764331210191085


51it [00:09,  5.55it/s]

Epoch 7, Step 49, Loss 0.008975770324468613


101it [00:18,  5.13it/s]

Epoch 7, Step 99, Loss 0.003425797214731574


151it [00:28,  5.66it/s]

Epoch 7, Step 149, Loss 0.0031149883288890123


201it [00:37,  5.30it/s]

Epoch 7, Step 199, Loss 0.23890675604343414


251it [00:46,  5.63it/s]

Epoch 7, Step 249, Loss 0.07237274199724197


301it [00:56,  5.27it/s]

Epoch 7, Step 299, Loss 0.08403600752353668


314it [00:58,  5.35it/s]
100%|██████████| 314/314 [00:20<00:00, 15.11it/s]


Epoch 7 A Accuracy: 0.8487261146496815


100%|██████████| 79/79 [00:05<00:00, 14.55it/s]


Epoch 7 B Accuracy: 0.4267515923566879


51it [00:09,  5.50it/s]

Epoch 8, Step 49, Loss 0.0034152022562921047


101it [00:18,  5.54it/s]

Epoch 8, Step 99, Loss 0.053864847868680954


151it [00:27,  5.57it/s]

Epoch 8, Step 149, Loss 0.08281024545431137


201it [00:37,  5.29it/s]

Epoch 8, Step 199, Loss 0.0018734787590801716


251it [00:46,  5.28it/s]

Epoch 8, Step 249, Loss 0.007653412874788046


301it [00:55,  5.58it/s]

Epoch 8, Step 299, Loss 0.004380011931061745


314it [00:58,  5.39it/s]
100%|██████████| 314/314 [00:20<00:00, 15.10it/s]


Epoch 8 A Accuracy: 0.9777070063694268


100%|██████████| 79/79 [00:05<00:00, 14.55it/s]


Epoch 8 B Accuracy: 0.40764331210191085


51it [00:09,  5.57it/s]

Epoch 9, Step 49, Loss 0.022633295506238937


101it [00:18,  5.45it/s]

Epoch 9, Step 99, Loss 0.0011627345811575651


151it [00:27,  5.52it/s]

Epoch 9, Step 149, Loss 0.008985363878309727


201it [00:37,  5.22it/s]

Epoch 9, Step 199, Loss 0.0018636435270309448


251it [00:46,  5.43it/s]

Epoch 9, Step 249, Loss 0.0018581327749416232


301it [00:55,  5.55it/s]

Epoch 9, Step 299, Loss 0.0002871573669835925


314it [00:58,  5.38it/s]
100%|██████████| 314/314 [00:20<00:00, 15.09it/s]


Epoch 9 A Accuracy: 0.9952229299363057


100%|██████████| 79/79 [00:05<00:00, 14.55it/s]


Epoch 9 B Accuracy: 0.39490445859872614
