In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
import dotenv
from pathlib import Path

env_file = "../.env"

if os.path.exists(env_file):
    dotenv.load_dotenv(env_file, verbose=True)
    print("Loaded environment variables from .env file.")

cwd = os.getcwd()
# for some reason appending to PATH you need it to be string
sys.path.append(str(Path(cwd).parent / "src"))
hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

Loaded environment variables from .env file.


In [2]:
import random

import torch
from research_tools import get_gpus_available
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from transformers import LlamaForCausalLM, LlamaTokenizer


# Restrict PyTorch to the GPUs reported by research_tools.get_gpus_available().
# CUDA reads CUDA_VISIBLE_DEVICES only when it first initialises, so this must run
# before any CUDA call; changing it after CUDA is initialised in this kernel has no effect.
gpu_ids = get_gpus_available()
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)

# Seed the Python and torch RNGs so LoRA initialisation and any random label
# shuffling are reproducible on "Restart & Run All" — TODO confirm the project's
# helpers draw from these generators.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

model_dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Explicit raise rather than assert, so the check survives `python -O`.
if device.type != "cuda":
    raise RuntimeError("No GPU available.")

model_name = "meta-llama/Meta-Llama-3-8B"

model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(
    model_name, token=hf_access_token, torch_dtype=model_dtype
)
model = model.to(device)

# AutoTokenizer resolves Llama 3 to a fast tokenizer (PreTrainedTokenizerFast), not the
# sentencepiece-based LlamaTokenizer, so annotate with the shared tokenizer base class.
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
    model_name, token=hf_access_token
)
# Reuse EOS as the padding token and pad on the left, so each row's real tokens end at
# the last position (presumably for batched generation/scoring — TODO confirm).
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [18]:
from peft import get_peft_model, LoraConfig, PeftModel

lora_rank = 64
# PEFT scales the LoRA update by lora_alpha / r (rsLoRA is off by default), so the
# effective scale here is 8 / 64 = 0.125 — small; confirm this is intended.
lora_alpha = 8

lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    target_modules=["q_proj", "v_proj"],
)

# Re-running this cell used to wrap an existing PeftModel in a second PeftModel.
# Strip any adapters already attached (unload() returns the base model without
# merging them) so every run of this cell yields exactly one fresh adapter.
if isinstance(model, PeftModel):
    model = model.unload()

model = get_peft_model(model, lora_config)

In [31]:
from unlearn_order.dataset import load_dataset

# Data lives in JSONL files split_<k>.jsonl under ../data/random_bd.
# The first `n_train` entries of `splits` form the train set (T), the next `n_val`
# entries form the validation set (V), and the combined set is T's files followed by V's.
data_dir = Path("../data/random_bd")

splits = list(range(10))
n_train = 1
n_val = 1

train_files = [f"split_{splits[idx]}.jsonl" for idx in range(n_train)]
val_files = [
    f"split_{splits[idx]}.jsonl"
    for idx in range(n_train, n_train + n_val)
]
combined_files = [*train_files, *val_files]

# Loaded in the order train, val, combined.
train_dataset, val_dataset, combined_dataset = (
    load_dataset(data_dir, files)
    for files in (train_files, val_files, combined_files)
)

In [None]:
# Original experiment, three fine-tuning phases applied to `model` in sequence:
#   1. learn T and V  — fine-tune on the combined dataset
#   2. forget T and V — fine-tune on the combined dataset with shuffle_labels=True
#   3. learn T        — fine-tune on the train split only
# Accuracy on the train and val splits is reported after phases 2 and 3.
from unlearn_order.finetune import finetune_model
from unlearn_order.eval import eval_dataset

tolerance = 0.1
batch_size = 32
# Only the unlearning call passes `lr`; the other finetune_model calls use its default.
lr = 3e-5

# Phase 1: learn T + V.
model, loss_traj, acc_traj = finetune_model(
    model, tokenizer, combined_dataset, batch_size=batch_size, tolerance=tolerance
)
print(f"Combine train acc: {acc_traj[-1]}")
combine_eval = eval_dataset(model, tokenizer, combined_dataset, batch_size=batch_size)
print(f"combine eval acc: {combine_eval}")

# Phase 2: unlearn T + V.
model, loss_traj, acc_traj = finetune_model(
    model,
    tokenizer,
    combined_dataset,
    batch_size=batch_size,
    shuffle_labels=True,
    tolerance=tolerance,
    lr=lr,
)
# Evaluated in order: train split first, then val split.
t_acc, v_acc = (
    eval_dataset(model, tokenizer, split_data, batch_size=batch_size)
    for split_data in (train_dataset, val_dataset)
)
print(f"Unlearn train accuracy: {t_acc}")
print(f"Unlearn val accuracy: {v_acc}")

# Phase 3: relearn T only, then evaluate both splits.
model, loss_traj, acc_traj = finetune_model(
    model, tokenizer, train_dataset, batch_size=batch_size, tolerance=tolerance
)
t_acc, v_acc = (
    eval_dataset(model, tokenizer, split_data, batch_size=batch_size)
    for split_data in (train_dataset, val_dataset)
)
print(f"Train accuracy: {t_acc}")
print(f"Val accuracy: {v_acc}")


 10%|█         | 10/100 [01:02<09:24,  6.27s/it]

Epoch 9 loss: 0.01222300339656271 acc: 0.3184713375796178


 20%|██        | 20/100 [02:05<08:22,  6.29s/it]

Epoch 19 loss: 0.010634039902383355 acc: 0.40764331210191085


 30%|███       | 30/100 [03:08<07:20,  6.30s/it]

Epoch 29 loss: 0.007796729208937116 acc: 0.45222929936305734


 40%|████      | 40/100 [04:11<06:17,  6.30s/it]

Epoch 39 loss: 0.0044266715361054535 acc: 0.6592356687898089


 50%|█████     | 50/100 [05:14<05:14,  6.29s/it]

Epoch 49 loss: 0.0019284283718580652 acc: 0.7292993630573248


 60%|██████    | 60/100 [06:17<04:11,  6.30s/it]

Epoch 59 loss: 0.0012011425806933147 acc: 0.7929936305732485


 70%|███████   | 70/100 [07:20<03:09,  6.30s/it]

Epoch 69 loss: 0.0007304081583549832 acc: 0.8312101910828026


 80%|████████  | 80/100 [08:23<02:05,  6.29s/it]

Epoch 79 loss: 0.0006901777948543524 acc: 0.8280254777070064


 90%|█████████ | 90/100 [09:26<01:02,  6.30s/it]

Epoch 89 loss: 0.00047678647378969724 acc: 0.856687898089172


 99%|█████████▉| 99/100 [10:22<00:06,  6.29s/it]

In [None]:
from unlearn_order.pipeline import run_pipeline

batch_size = 28
tolerance = 0.05
lr = 3e-6

# Stages are (kind, name, dataset) tuples executed in order by run_pipeline.
# From the stage names, "f" looks like fine-tune, "u" like unlearn and "e" like
# evaluate — TODO confirm against run_pipeline.
# Every tuple needs a trailing comma: without one, the next parenthesised tuple is
# parsed as a call on the previous tuple and raises "'tuple' object is not callable".
# NOTE(review): `model` is the object earlier cells fine-tuned (finetune_model reassigns
# it above), so this run presumably does not start from a fresh adapter. If it should,
# re-create the LoRA adapter first — TODO confirm intent.
run_pipeline(
    model,
    tokenizer,
    [
        ("f", "combined", combined_dataset),
        ("u", "unlearn", combined_dataset),
        ("e", "eval_train", train_dataset),
        ("e", "eval_val", val_dataset),
        ("f", "retrain_train", train_dataset),
        ("e", "eval_train", train_dataset),
        ("e", "eval_val", val_dataset),
    ],
    batch_size=batch_size,
    tolerance=tolerance,
    lr=lr,
)


In [None]:
from unlearn_order.pipeline import run_pipeline

batch_size = 28
tolerance = 0.05
# `lr` was previously not set in this cell and silently came from whichever cell ran
# last (3e-6 when the notebook runs top to bottom); pin it so the cell stands alone.
lr = 3e-6

# Stages are (kind, name, dataset) tuples executed in order by run_pipeline.
# From the stage names, "f" looks like fine-tune, "u" like unlearn and "e" like
# evaluate — TODO confirm against run_pipeline.
# Every tuple needs a trailing comma: without one, the next parenthesised tuple is
# parsed as a call on the previous tuple and raises "'tuple' object is not callable".
# NOTE(review): `model` is the object earlier cells fine-tuned, so this ordering
# presumably does not start from a fresh adapter. If orderings are meant to be compared
# from the same starting point, re-create the LoRA adapter first — TODO confirm intent.
run_pipeline(
    model,
    tokenizer,
    [
        # learn the train split, then the val split
        ("f", "train_train", train_dataset),
        ("f", "val_train", val_dataset),
        ("e", "eval_train", train_dataset),
        ("e", "eval_val", val_dataset),
        # unlearn on the combined set
        ("u", "unlearn", combined_dataset),
        ("e", "eval_train", train_dataset),
        ("e", "eval_val", val_dataset),
        # relearn the train split only
        ("f", "retrain_train", train_dataset),
        ("e", "eval_train", train_dataset),
        ("e", "eval_val", val_dataset),
    ],
    batch_size=batch_size,
    tolerance=tolerance,
    lr=lr,
)


In [None]:
from unlearn_order.pipeline import run_pipeline

batch_size = 28
tolerance = 0.05
# `lr` was previously not set in this cell and silently came from whichever cell ran
# last (3e-6 when the notebook runs top to bottom); pin it so the cell stands alone.
lr = 3e-6

# Stages are (kind, name, dataset) tuples executed in order by run_pipeline.
# From the stage names, "f" looks like fine-tune, "u" like unlearn and "e" like
# evaluate — TODO confirm against run_pipeline.
# Every tuple needs a trailing comma: without one, the next parenthesised tuple is
# parsed as a call on the previous tuple and raises "'tuple' object is not callable".
# NOTE(review): `model` is the object earlier cells fine-tuned, so this ordering
# presumably does not start from a fresh adapter. If orderings are meant to be compared
# from the same starting point, re-create the LoRA adapter first — TODO confirm intent.
run_pipeline(
    model,
    tokenizer,
    [
        # learn the val split, then the train split
        ("f", "val_train", val_dataset),
        ("f", "train_train", train_dataset),
        ("e", "eval_train", train_dataset),
        ("e", "eval_val", val_dataset),
        # unlearn on the combined set
        ("u", "unlearn", combined_dataset),
        ("e", "eval_train", train_dataset),
        ("e", "eval_val", val_dataset),
        # relearn the train split only
        ("f", "retrain_train", train_dataset),
        ("e", "eval_train", train_dataset),
        ("e", "eval_val", val_dataset),
    ],
    batch_size=batch_size,
    tolerance=tolerance,
    lr=lr,
)
