In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
import dotenv
from pathlib import Path

env_file = "../.env"

if os.path.exists(env_file):
    dotenv.load_dotenv(env_file, verbose=True)
    print("Loaded environment variables from .env file.")

cwd = os.getcwd()
# for some reason appending to PATH you need it to be string
sys.path.append(str(Path(cwd).parent / "src"))
hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

# from research_tools.gpu import get_gpus_available

# os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in get_gpus_available()])
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# get rid of linker warnings that show up for some reason
os.environ["OTHER_LDFLAGS"] = "-w"

Loaded environment variables from .env file.


In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Optional
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device == torch.device("cuda")

model_id = "LLM-LAT/zephyr7b-beta-rmu-lat-unlearn-wmdp-bio-cyber"
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "google/gemma-2b"
model_id = "HuggingFaceH4/zephyr-7b-beta"
model_id = "meta-llama/Meta-Llama-3-8B"
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, token=hf_access_token
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
from unlearn_order.datasets.utils import DATASETS_DICT, Datasets


dataset_config = DATASETS_DICT[Datasets.WMDP_MCQ_CORPUS]

In [4]:
# ascent on corpus split, descent on wikitext
# eval loss on regular split mcq val split, wikitext val split
# also should eval on train split mcq val split, wikitext, mmlu cats
# all val files should be mcq
# all train should be corpus
# unless u do rtt

from datasets import load_dataset
from unlearn_order.dataset import load_dataset as load_dataset_unlearn


data_dir = Path("../data")

retain_dataset = load_dataset_unlearn(data_dir, dataset_config["retain_files"])


forget_train = load_dataset_unlearn(data_dir, dataset_config["unlearn_files"])
forget_val = load_dataset_unlearn(data_dir, dataset_config["val_files"])

retain_train = load_dataset("wikitext", "wikitext-2-raw-v1")["validation"]
retain_val = load_dataset_unlearn(data_dir, dataset_config["val_retain_files"])

In [5]:
from unlearn_order.utils import (
    create_prompt_letter_answer,
    create_prompt,
    create_prompt_question_answer,
    create_prompt_answer_only,
)

completion_func = create_prompt_letter_answer

In [6]:
from datasets import Dataset
from transformers import AutoTokenizer
from functools import partial


def map_corpus(data: Dict, tokenizer: AutoTokenizer, max_length: int):
    text = data["text"]
    output = tokenizer(
        text,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    return {
        "input_ids": output["input_ids"].squeeze(),
        "attention_mask": output["attention_mask"].squeeze(),
        "labels": output["input_ids"].clone().squeeze(),
    }


def process_train_dataset(
    dataset: Dataset,
    tokenizer: AutoTokenizer,
    max_length: int = 512,
    min_length: int = 0,
):
    dataset = dataset.filter(lambda x: len(x["text"]) > min_length)
    dataset = dataset.map(
        partial(map_corpus, tokenizer=tokenizer, max_length=max_length)
    )

    keep_cols = ["input_ids", "attention_mask", "labels"]
    dataset = dataset.remove_columns(
        [col for col in dataset.column_names if col not in keep_cols]
    )
    dataset.set_format("torch")
    return dataset


max_length = 128
forget_train = process_train_dataset(forget_train, tokenizer, max_length=max_length)
retain_train = process_train_dataset(retain_train, tokenizer, max_length=max_length)

Filter:   0%|          | 0/2355 [00:00<?, ? examples/s]

Map:   0%|          | 0/2355 [00:00<?, ? examples/s]

In [7]:
from copy import deepcopy
import torch.nn.functional as F


def create_mcq(
    record: Dict, tokenizer: AutoTokenizer, max_length: int, context: str = ""
):
    record["question"] = context + record["question"]
    text = completion_func(record)
    prompt = create_prompt(record)

    completion = text[len(prompt) :]

    record["prompt"] = prompt
    record["text"] = text
    record["completion"] = completion

    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]
    completion_ids = tokenizer(
        completion, return_tensors="pt", add_special_tokens=False
    ).input_ids[0]

    input_ids = torch.cat([prompt_ids, completion_ids])
    labels = input_ids.clone()

    attention_mask = torch.ones_like(input_ids)
    prompt_mask = torch.zeros_like(input_ids)
    prompt_mask[: len(prompt_ids)] = 1

    completion_mask = torch.zeros_like(input_ids)
    completion_mask[len(prompt_ids) :] = 1

    # pad to max length on left
    seq_len = input_ids.size(0)
    pad_len = max_length - seq_len
    padding = (0, pad_len)

    input_ids = F.pad(input_ids, padding, value=0)
    labels = F.pad(labels, padding, value=-100)
    attention_mask = F.pad(attention_mask, padding, value=0)
    prompt_mask = F.pad(prompt_mask, padding, value=0)
    completion_mask = F.pad(completion_mask, padding, value=0)

    record["input_ids"] = input_ids
    record["labels"] = labels
    record["attention_mask"] = attention_mask
    record["completion_byte_len"] = len(completion.encode("utf-8"))
    record["prompt_mask"] = prompt_mask
    record["completion_mask"] = completion_mask
    record["length"] = seq_len

    return record


def expand_mcq_records(records: List[Dict], **kwargs):
    new_records = []
    for rec in records:
        n_choices = len(rec["choices"])
        for i in range(n_choices):
            new_rec = deepcopy(rec)
            actual_answer = rec["answer"]
            new_rec["answer"] = i
            new_rec = create_mcq(new_rec, **kwargs)
            new_rec["answer"] = actual_answer
            new_rec["selected_answer"] = i
            new_records.append(new_rec)

    return new_records

In [8]:
def expand_corpus_records(records: List[Dict]):
    for rec in records:
        rec["input_ids"] = torch.tensor(rec["input_ids"])
        rec["labels"] = torch.tensor(rec["labels"])
        rec["attention_mask"] = torch.tensor(rec["attention_mask"])
        rec["length"] = rec["input_ids"].size(0)
    return records

In [9]:
max_length = 1024

forget_train_records = forget_train.to_list()
retain_train_records = retain_train.to_list()
forget_val_records = forget_val.to_list()
retain_val_records = retain_val.to_list()

forget_train_records = expand_corpus_records(forget_train_records)
retain_train_records = expand_corpus_records(retain_train_records)
retain_val_records = expand_mcq_records(
    retain_val_records, tokenizer=tokenizer, max_length=max_length
)
forget_val_records = expand_mcq_records(
    forget_val_records, tokenizer=tokenizer, max_length=max_length
)

In [10]:
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm


def fix_seq_len(batch: Dict, keys: List[str]):
    max_seq_len = max(batch["length"])
    for key in keys:
        batch[key] = batch[key][:, :max_seq_len]
    return batch


def evaluate(
    model: AutoModelForCausalLM,
    records: List[Dict],
    batch_size: int = 8,
    n_choices: int = 4,
    normalize_loss: bool = False,
):
    # round up to nearest multiple of n_choices
    batch_size = (batch_size + n_choices - 1) // n_choices * n_choices
    original_fields = ["input_ids", "attention_mask", "labels"]

    aux_fields = ["answer", "completion_byte_len", "completion_mask", "length"]
    fields = original_fields + aux_fields
    dataset = [{k: v for k, v in rec.items() if k in fields} for rec in records]

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = fix_seq_len(batch, original_fields + ["completion_mask"])
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            completion_mask = batch["completion_mask"].to(device)
            attention_mask = batch["attention_mask"].to(device).bool()
            completion_byte_len = batch["completion_byte_len"].to(device)

            answers = batch["answer"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            logits = outputs.logits
            labels[~completion_mask.bool()] = -100
            labels[~attention_mask.bool()] = -100

            shifted_logits = logits[:, :-1, :].contiguous()
            shifted_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                shifted_labels.view(-1),
                reduction="none",
            )
            loss = loss.view(shifted_labels.size(0), shifted_labels.size(1))
            loss_by_sample = loss.sum(dim=1)

            if normalize_loss:
                loss_by_sample = loss_by_sample / completion_byte_len
            loss_by_sample = loss_by_sample.view(-1, n_choices)

            answers = answers[::n_choices]

            pred = loss_by_sample.argmin(dim=1)

            correct += (pred == answers).sum().item()
            total += len(answers)

    accuracy = correct / total
    return accuracy

In [11]:
from torch.utils.data import DataLoader
from transformers import Trainer, TrainingArguments
from unlearn_order.utils import log_1_minus_p_loss, my_loss
from unlearn_order.eval import eval_dataset as my_eval_dataset


def train_step(
    model: AutoModelForCausalLM,
    forget_batch: Dict,
    retain_batch: Dict,
    forget_alpha: float = 0.1,
):
    forget_batch = fix_seq_len(forget_batch, ["input_ids", "attention_mask", "labels"])
    retain_batch = fix_seq_len(retain_batch, ["input_ids", "attention_mask", "labels"])

    forget_input_ids = forget_batch["input_ids"].to(device)
    forget_attention_mask = forget_batch["attention_mask"].to(device)
    forget_labels = forget_batch["labels"].to(device)

    retain_input_ids = retain_batch["input_ids"].to(device)
    retain_attention_mask = retain_batch["attention_mask"].to(device)
    retain_labels = retain_batch["labels"].to(device)

    # forget_loss = my_loss(
    #     model,
    #     is_away=True,
    #     input_ids=forget_input_ids,
    #     attention_mask=forget_attention_mask,
    #     labels=forget_labels,
    # )

    # forget_loss = forget_loss * forget_alpha
    # forget_loss.backward()

    retain_loss = my_loss(
        model,
        is_away=False,
        input_ids=retain_input_ids,
        attention_mask=retain_attention_mask,
        labels=retain_labels,
    )

    retain_loss.backward()

    # loss = forget_loss.detach() + retain_loss.detach()
    loss = retain_loss.detach()
    loss = loss.detach().item()
    # forget_loss = forget_loss.detach().item()
    forget_loss = 0
    retain_loss = retain_loss.detach().item()

    loss_dict = {
        "loss": loss,
        "forget_loss": forget_loss,
        "retain_loss": retain_loss,
    }

    return loss_dict

In [12]:
def train_epoch(
    model: AutoModelForCausalLM,
    epoch: int,
    forget_train_records: List[Dict],
    retain_train_records: List[Dict],
    batch_size: int = 4,
    forget_alpha: float = 0.1,
    lr: float = 3e-5,
    log_steps: int = 50,
    grad_accum_steps: int = 1,
):
    model.train()
    forget_dataloader = DataLoader(
        forget_train_records, batch_size=batch_size, shuffle=True
    )
    retain_dataloader = DataLoader(
        retain_train_records, batch_size=batch_size, shuffle=True
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_traj = []
    for step, (forget_batch, retain_batch) in tqdm(
        enumerate(zip(forget_dataloader, retain_dataloader))
    ):
        loss_dict = train_step(model, forget_batch, retain_batch, forget_alpha)

        if (step + 1) % grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        loss_traj.append(loss_dict)
        if (step + 1) % log_steps == 0:
            print(
                f"Epoch {epoch}, Step {step}, Loss {loss_dict['loss']}, Forget Loss {loss_dict['forget_loss']}, Retain Loss {loss_dict['retain_loss']}"
            )

    return loss_traj


def train(
    model: AutoModelForCausalLM,
    n_epochs: int,
    forget_train_records: List[Dict],
    retain_train_records: List[Dict],
    forget_val_records: List[Dict],
    retain_val_records: List[Dict],
    batch_size: int = 4,
    forget_alpha: float = 0.1,
    lr: float = 3e-5,
    log_steps: int = 50,
    eval_at_start: bool = True,
    grad_accum_steps: int = 1,
):
    if eval_at_start:
        retain_acc = evaluate(
            model, retain_val_records, batch_size=8, normalize_loss=False
        )
        forget_acc = evaluate(
            model, forget_val_records, batch_size=8, normalize_loss=False
        )

        print(f"Initial Retain Accuracy: {retain_acc}")
        print(f"Initial Forget Accuracy: {forget_acc}")

    for epoch in range(n_epochs):
        train_epoch(
            model,
            epoch,
            forget_train_records,
            retain_train_records,
            batch_size,
            forget_alpha,
            lr=lr,
            log_steps=log_steps,
            grad_accum_steps=grad_accum_steps,
        )
        retain_acc = evaluate(
            model, retain_val_records, batch_size=8, normalize_loss=False
        )
        forget_acc = evaluate(
            model, forget_val_records, batch_size=8, normalize_loss=False
        )

        print(f"Epoch: {epoch}, Retain Accuracy: {retain_acc}")
        print(f"Epoch: {epoch}, Forget Accuracy: {forget_acc}")
    return model

In [None]:
model = train(
    model,
    2,
    forget_train_records,
    retain_train_records[: 500 * 4],
    forget_val_records,
    retain_val_records,
    4,
    forget_alpha=0.0,
    eval_at_start=False,
    log_steps=50,
    lr=3e-5,
    grad_accum_steps=8,
)

51it [00:08,  6.34it/s]

Epoch 0, Step 49, Loss 2.4056313037872314, Forget Loss 0, Retain Loss 2.4056313037872314


101it [00:16,  6.74it/s]

Epoch 0, Step 99, Loss 2.4968717098236084, Forget Loss 0, Retain Loss 2.4968717098236084


151it [00:24,  6.85it/s]

Epoch 0, Step 149, Loss 2.1581199169158936, Forget Loss 0, Retain Loss 2.1581199169158936


201it [00:32,  5.69it/s]

Epoch 0, Step 199, Loss 1.9347772598266602, Forget Loss 0, Retain Loss 1.9347772598266602


251it [00:40,  6.31it/s]

Epoch 0, Step 249, Loss 2.0560457706451416, Forget Loss 0, Retain Loss 2.0560457706451416


301it [00:48,  6.78it/s]

Epoch 0, Step 299, Loss 1.6730653047561646, Forget Loss 0, Retain Loss 1.6730653047561646


351it [00:56,  6.91it/s]

Epoch 0, Step 349, Loss 1.447522759437561, Forget Loss 0, Retain Loss 1.447522759437561


401it [01:04,  5.55it/s]

Epoch 0, Step 399, Loss 1.8087211847305298, Forget Loss 0, Retain Loss 1.8087211847305298


451it [01:12,  6.35it/s]

Epoch 0, Step 449, Loss 1.664201259613037, Forget Loss 0, Retain Loss 1.664201259613037


500it [01:20,  6.22it/s]


Epoch 0, Step 499, Loss 2.21142315864563, Forget Loss 0, Retain Loss 2.21142315864563


  0%|          | 0/393 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
100%|██████████| 393/393 [00:41<00:00,  9.36it/s]
100%|██████████| 393/393 [00:24<00:00, 16.08it/s]


Epoch: 0, Retain Accuracy: 0.5910828025477707
Epoch: 0, Forget Accuracy: 0.5528662420382165


51it [00:08,  6.30it/s]

Epoch 1, Step 49, Loss 1.0882759094238281, Forget Loss 0, Retain Loss 1.0882759094238281


95it [00:15,  6.96it/s]