In [1]:
%load_ext autoreload
%autoreload 2

import os
from dotenv import load_dotenv
from pathlib import Path
from research_tools.gpu import get_gpus_available

load_dotenv()


hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

n_gpus = 1

gpus_available = get_gpus_available()
n_gpus = min(n_gpus, len(gpus_available))
gpus = gpus_available[:n_gpus]

assert n_gpus > 0, "No GPUs available"

os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpus])

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Optional
import torch
from research_tools.utils import set_seed

set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device == torch.device("cuda")

model_id = "HuggingFaceH4/zephyr-7b-beta"

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, token=hf_access_token
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
from datasets import load_dataset
from relearn.datasets.utils import (
    load_dataset as local_load_dataset,
    DATASETS_DICT,
    Datasets,
)
from relearn.datasets.corpus import process as process_corpus
from relearn.datasets.mcq import process as process_mcq
import pickle

data_dir = Path("../data")
cache_path = data_dir / "data.pickle"

USE_CACHE = False

if USE_CACHE:
    assert cache_path.exists(), "Cache file does not exist"
    with open(cache_path, "rb") as f:
        data = pickle.load(f)
else:
    data = {}
    # iterate over all enums

    for name in Datasets:

        dataset_config = DATASETS_DICT[name]

        def get_dataset(train_files: List[str], val_files: List[str], max_length: int):
            train = local_load_dataset(data_dir, train_files)
            val = local_load_dataset(data_dir, val_files)
            train_records = process_corpus(train, tokenizer, max_length)
            val_records = process_mcq(val, tokenizer, max_length)
            mcq_records = process_mcq(val, tokenizer, max_length, expand_choices=False)
            return {
                "corpus": train_records,
                "mcq": mcq_records,
                "val": val_records,
            }

        split_sizes = [2, 2, 1]
        max_length = 512

        unlearn_files = dataset_config["unlearn_files"]
        val_unlearn_files = dataset_config["val_unlearn_files"]

        start = 0

        store = {}
        for i, n_val_files in enumerate(split_sizes):
            end = start + n_val_files
            print(f"Processing {name} {i} {start} {end}")

            cur = chr(ord("A") + i)
            store[cur] = get_dataset(
                unlearn_files[start:end], val_unlearn_files[start:end], max_length
            )
            start = end

        if "retain" not in data:
            retain_files = dataset_config["retain_files"]
            val_retain_files = dataset_config["val_retain_files"]
            data["retain"] = get_dataset(retain_files, val_retain_files, max_length)

        data[name] = store

    with open(cache_path, "wb") as f:
        pickle.dump(data, f)

Processing Datasets.MMLU 0 0 2


Filter:   0%|          | 0/942 [00:00<?, ? examples/s]

Map:   0%|          | 0/942 [00:00<?, ? examples/s]

Processing Datasets.MMLU 1 2 4


Filter:   0%|          | 0/942 [00:00<?, ? examples/s]

Map:   0%|          | 0/942 [00:00<?, ? examples/s]

Processing Datasets.MMLU 2 4 5


Filter:   0%|          | 0/471 [00:00<?, ? examples/s]

Map:   0%|          | 0/471 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Processing Datasets.YEARS 0 0 2


Filter:   0%|          | 0/942 [00:00<?, ? examples/s]

Map:   0%|          | 0/942 [00:00<?, ? examples/s]

Processing Datasets.YEARS 1 2 4


Filter:   0%|          | 0/942 [00:00<?, ? examples/s]

Map:   0%|          | 0/942 [00:00<?, ? examples/s]

Processing Datasets.YEARS 2 4 5


Filter:   0%|          | 0/471 [00:00<?, ? examples/s]

Map:   0%|          | 0/471 [00:00<?, ? examples/s]

Processing Datasets.WMDP 0 0 2


Filter:   0%|          | 0/942 [00:00<?, ? examples/s]

Map:   0%|          | 0/942 [00:00<?, ? examples/s]

Processing Datasets.WMDP 1 2 4


Filter:   0%|          | 0/942 [00:00<?, ? examples/s]

Map:   0%|          | 0/942 [00:00<?, ? examples/s]

Processing Datasets.WMDP 2 4 5


Filter:   0%|          | 0/471 [00:00<?, ? examples/s]

Map:   0%|          | 0/471 [00:00<?, ? examples/s]

Processing Datasets.RANDOM_BD 0 0 2


Filter:   0%|          | 0/942 [00:00<?, ? examples/s]

Map:   0%|          | 0/942 [00:00<?, ? examples/s]

Processing Datasets.RANDOM_BD 1 2 4


Filter:   0%|          | 0/942 [00:00<?, ? examples/s]

Map:   0%|          | 0/942 [00:00<?, ? examples/s]

Processing Datasets.RANDOM_BD 2 4 5


Filter:   0%|          | 0/471 [00:00<?, ? examples/s]

Map:   0%|          | 0/471 [00:00<?, ? examples/s]

: 

In [4]:
store = {
    "A": data[Datasets.WMDP]["A"],
    "B": data[Datasets.RANDOM_BD]["A"],
    "retain": data["retain"],
}

In [5]:
eval_dict = {k: v["val"] for k, v in store.items()}

In [6]:
import wandb

wandb.login()

config = {
    "model_id": model_id,
    "magnitude": 6.5,
    "lr": 1e-5,
    "n_epochs": 1,
    "forget_alphas": {"A": 0.39422, "B": 0.39422},
    "retain_alphas": {"B": 13.51609, "retain": 1},
    "max_batches": 100,
}


run = wandb.init(
    project="relearn", config=config, tags=["rmu", "debug"], entity="12tqian"
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33m12tqian[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
from pathlib import Path

path = Path("../models/random_bd") / "init"
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(
    device
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
from relearn.unlearn.rmu import train_rmu

config = run.config

model = train_rmu(
    model,
    {"A": store["A"]["corpus"]},
    {"B": store["B"]["corpus"], "retain": store["retain"]["corpus"]},
    eval_records_dict=eval_dict,
    n_epochs=config["n_epochs"],
    magnitude=config["magnitude"],
    lr=config["lr"],
    forget_alphas=config["forget_alphas"],
    retain_alphas=config["retain_alphas"],
    eval_at_start=True,
    max_batches=None,
    use_wandb=True,
    debug=False,
    tokenizer=tokenizer,
)

100%|██████████| 157/157 [00:10<00:00, 15.31it/s]
100%|██████████| 157/157 [00:05<00:00, 28.04it/s]
100%|██████████| 393/393 [00:42<00:00,  9.24it/s]
  forget_loss = torch.nn.functional.mse_loss(
  forget_loss = torch.nn.functional.mse_loss(
100%|██████████| 236/236 [02:28<00:00,  1.59it/s, A/forget_loss=0.0146, B/retain_loss=1.5e-5, retain/retain_loss=4.01e-5] 
100%|██████████| 157/157 [00:10<00:00, 15.24it/s]
100%|██████████| 157/157 [00:05<00:00, 27.96it/s]
100%|██████████| 393/393 [00:42<00:00,  9.19it/s]


In [None]:
from relearn.unlearn.rmu import train_rmu

run.finish()

run = wandb.init(
    project="relearn", config=config, tags=["rmu", "debug"], entity="12tqian"
)


model = train_rmu(
    model,
    {"B": store["B"]["corpus"]},
    {"retain": store["retain"]["corpus"]},
    eval_records_dict=eval_dict,
    n_epochs=config["n_epochs"],
    magnitude=config["magnitude"],
    lr=config["lr"],
    forget_alphas=config["forget_alphas"],
    retain_alphas=config["retain_alphas"],
    eval_at_start=True,
    max_batches=None,
    use_wandb=True,
    debug=False,
    tokenizer=tokenizer,
    max_batches=config["max_batches"],
)
run.finish()

100%|██████████| 157/157 [00:10<00:00, 15.46it/s]
100%|██████████| 157/157 [00:05<00:00, 28.23it/s]
100%|██████████| 393/393 [00:42<00:00,  9.29it/s]
  forget_loss = torch.nn.functional.mse_loss(
  forget_loss = torch.nn.functional.mse_loss(
100%|██████████| 236/236 [01:29<00:00,  2.63it/s, B/forget_loss=0.0145, retain/retain_loss=4.32e-6]
100%|██████████| 157/157 [00:10<00:00, 15.31it/s]
100%|██████████| 157/157 [00:05<00:00, 28.10it/s]
100%|██████████| 393/393 [00:43<00:00,  8.98it/s]


In [9]:
path = Path("../models/wmdp_random_bd") / "final"
# model.save_pretrained(path)

model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(
    device
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [16]:
input_text = "How are you doing?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
output = model.generate(input_ids, max_length=100, num_return_sequences=1)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
print("PROMPT:", input_text)
print("RESPONSE:", output_text[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


PROMPT: How are you doing?
RESPONSE: How are you doing?
I’m doing well. I’m excited to be here.
What are you most looking forward to about the season?
I’m looking forward to the season. I’m looking forward to the season because I’ve been working hard in the offseason. I’ve been working hard in the offseason to get my body right and get my body in the right place to be able to perform at a high level. I’m looking


In [8]:
from relearn.attacks.rtt import train_rtt


run.finish()

run = wandb.init(
    project="relearn", config=config, tags=["rtt", "debug"], entity="12tqian"
)

new_eval_dict = {k: v for k, v in eval_dict.items() if k != "retain"}

# relearm only A
model = train_rtt(
    model,
    tokenizer,
    10,
    store["A"]["mcq"],
    new_eval_dict,
    batch_size=2,
    lr=1e-6,
    eval_at_start=False,
    grad_accum_steps=2,
    use_wandb=True,
)

157it [00:29,  5.23it/s]
100%|██████████| 157/157 [00:10<00:00, 15.35it/s]
100%|██████████| 157/157 [00:05<00:00, 28.04it/s]
157it [00:29,  5.35it/s]
100%|██████████| 157/157 [00:10<00:00, 15.30it/s]
100%|██████████| 157/157 [00:05<00:00, 27.97it/s]
157it [00:29,  5.36it/s]
100%|██████████| 157/157 [00:10<00:00, 15.29it/s]
100%|██████████| 157/157 [00:05<00:00, 27.96it/s]
157it [00:29,  5.34it/s]
100%|██████████| 157/157 [00:10<00:00, 15.29it/s]
100%|██████████| 157/157 [00:05<00:00, 27.85it/s]
157it [00:30,  5.17it/s]
100%|██████████| 157/157 [00:10<00:00, 15.24it/s]
100%|██████████| 157/157 [00:05<00:00, 27.90it/s]
157it [00:28,  5.46it/s]
100%|██████████| 157/157 [00:10<00:00, 15.24it/s]
100%|██████████| 157/157 [00:05<00:00, 27.89it/s]
157it [00:28,  5.43it/s]
100%|██████████| 157/157 [00:10<00:00, 15.24it/s]
100%|██████████| 157/157 [00:05<00:00, 27.87it/s]
157it [00:29,  5.36it/s]
100%|██████████| 157/157 [00:10<00:00, 15.02it/s]
100%|██████████| 157/157 [00:06<00:00, 23.66it/s]
