In [2]:
import torch

torch.cuda.is_available()

True

In [5]:
%load_ext autoreload
%autoreload

import sys
import os
import dotenv
from pathlib import Path

env_file = "../.env"

if os.path.exists(env_file):
    dotenv.load_dotenv(env_file, verbose=True)
    print("Loaded environment variables from .env file.")

cwd = os.getcwd()
# for some reason appending to PATH you need it to be string
sys.path.append(str(Path(cwd).parent / "src"))
hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Loaded environment variables from .env file.


In [6]:
from research_tools.gpu import get_gpus_available

os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in get_gpus_available()])

In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device == torch.device("cuda")

model_id = "HuggingFaceH4/zephyr-7b-beta"
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, token=hf_access_token
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [8]:
from unlearn_order.dataset import load_dataset
from unlearn_order.eval import eval_dataset

data_dir = Path("../data/dates-years-trimmed")
n_train = 3
n_val = 2
batch_size = 4
train_files = [f"split_{i}.jsonl" for i in range(n_train)]
val_files = [f"split_{i}.jsonl" for i in range(n_train, n_train + n_val)]
train_dataset = load_dataset(data_dir, train_files)
val_dataset = load_dataset(data_dir, val_files)


acc = eval_dataset(model, tokenizer, train_dataset, batch_size=batch_size)
print(acc)

0.6012658227848101


In [9]:
import gc

gc.collect()

torch.cuda.empty_cache()

In [10]:
# # first we want to finetune on the dataset

# from unlearn_order.finetune import finetune_model


# # with profile(activities=[ProfilerActivity.CPU],
# #         profile_memory=True, record_shapes=True) as prof:
# #     model(inputs)

# # print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
# model = finetune_model(model, tokenizer, train_dataset, batch_size)

In [11]:
unlearn_train_files = [f"corpus_split_{i}.jsonl" for i in range(n_train)]
unlearn_corpus = load_dataset(data_dir, unlearn_train_files)

In [12]:
data_list = unlearn_corpus["text"]
data_list = [
    data_list[i : i + batch_size] for i in range(0, len(data_list), batch_size)
]

In [13]:
from datasets import load_dataset
import json

unlearning_task = "cyber"

forget_corpus = (
    "cyber-forget-corpus" if unlearning_task == "cyber" else "bio-forget-corpus"
)
retain_corpus = "wikitext"


def get_unlearning_rmu_dataset(name, min_len=0, batch_size=4):
    data = []
    if name == "wikitext":
        raw_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        for x in raw_data:
            if len(x["text"]) > min_len:
                data.append(str(x["text"]))
    elif name == "cyber-forget-corpus":
        raw_data = load_dataset("cais/wmdp-corpora", name=name, split="train")
        for x in raw_data:
            if len(x["text"]) > min_len:
                data.append(str(x["text"]))
    else:  # wmdp bio must be pre-downloaded
        for line in open(f"data/{name}.jsonl", "r"):
            raw_text = json.loads(line)["text"]
            if len(raw_text) > min_len:
                data.append(str(raw_text))
    data = [data[i : i + batch_size] for i in range(0, len(data), batch_size)]
    return data


forget_data_list = data_list
retain_data_list = get_unlearning_rmu_dataset(retain_corpus)

In [14]:
from tqdm import tqdm
from unlearn_order.algos.lat.hooks import add_hooks
from unlearn_order.algos.lat.model import get_adversary_hook, projected_gradient_descent
from unlearn_order.algos.lat.utils import (
    get_layer_module,
    set_model_grads,
    set_params_grads,
)
from typing import List, Dict
from transformers import LlamaForCausalLM, LlamaTokenizer
from torch.profiler import profile, record_function, ProfilerActivity


def get_layers_params(
    model: LlamaForCausalLM, layers: List[int], target_modules: List[str]
):
    params = []
    for layer in layers:
        for name, module in model.model.layers[layer].named_modules():
            if any([target in name for target in target_modules]):
                for param in module.parameters():
                    params.append(param)
    return params


def forward_with_cache(
    model: LlamaForCausalLM,
    inputs: Dict[str, torch.Tensor],
    module: torch.nn.Module,
    no_grad: bool = True,
):
    # define a tensor with the size of our cached activations
    cache = []

    def hook(module, input, output):
        if isinstance(output, tuple):
            cache.append(output[0])
        else:
            cache.append(output)
        return None

    hook_handle = module.register_forward_hook(hook)

    if no_grad:
        with torch.no_grad():
            _ = model(**inputs)
    else:
        _ = model(**inputs)

    hook_handle.remove()

    return cache[0]


# you're optimizing layer_ids
# there is a layer where you're examining the activations at and trying to make garbage
# then you only train on the laeyrs before that, this is just shoving garbage in
# not real unlearning?
# let us also latently perturb at the same layer


# remember params
def run_rmu(
    updated_model: LlamaForCausalLM,
    frozen_model: LlamaForCausalLM,
    tokenizer: LlamaTokenizer,
    forget_data_list,
    retain_data_list,
    target_modules: List[str] = ["down_proj"],
    alpha: float = 1200.0,
    layer: int = 8,  # layers to do RMU in
    lr_def: float = 5.0e-5,
    max_num_batches: int = 200,
    steering_coef: float = 6.5,
    use_lat: bool = True,
    epsilon: float = 2,
    lr_adv: float = 5e-2,
    n_pgd_steps: int = 16,
    loss_coefs: Dict[str, float] = {"towards": 1, "away": 1},
    num_epochs: int = 1,
    n_def_per_adv_steps: int = 1,
):
    adv_layer = layer - 1
    def_layers = [layer, layer - 1, layer - 2]

    updated_model.train()

    params = get_layers_params(updated_model, def_layers, target_modules)
    optimizer = torch.optim.AdamW(params, lr=lr_def)

    updated_module = get_layer_module(updated_model, layer)
    frozen_module = get_layer_module(frozen_model, layer)

    control_vectors_list = []
    for i in range(len(forget_data_list)):
        random_vector = torch.rand(
            1,
            1,
            updated_model.config.hidden_size,
            dtype=updated_model.dtype,
            device=updated_model.device,
        )
        control_vec = random_vector / torch.norm(random_vector) * steering_coef
        control_vectors_list.append(control_vec)

    num_batches = max_num_batches
    truncation_side = tokenizer.truncation_side
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.truncation_side = "right"

    for epoch in range(num_epochs):
        for idx in tqdm(range(num_batches)):
            set_params_grads(params, True)

            control_vec = control_vectors_list[idx]
            unlearn_batch = forget_data_list[idx]
            retain_batch = retain_data_list[idx]

            max_length = 512 if idx == 0 else 768

            unlearn_inputs = tokenizer(
                unlearn_batch,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_length,
            ).to(updated_model.device)

            retain_inputs = tokenizer(
                retain_batch,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_length,
            ).to(updated_model.device)

            if use_lat:
                adv_labels_mask = torch.zeros_like(
                    unlearn_inputs["input_ids"], dtype=bool
                )
                def_labels_mask = torch.zeros_like(
                    retain_inputs["input_ids"], dtype=bool
                )

                for b, example in enumerate(retain_batch):
                    len_example = len(tokenizer(example)["input_ids"])
                    def_labels_mask[b, :len_example] = True
                for b, example in enumerate(unlearn_batch):
                    len_example = len(tokenizer(example)["input_ids"])
                    adv_labels_mask[b, :len_example] = True

                prompt_mask = adv_labels_mask
                pgd_batch = {
                    "def_tokens": retain_inputs["input_ids"].to(updated_model.device),
                    "adv_tokens": unlearn_inputs["input_ids"].to(updated_model.device),
                    "adv_labels_mask": adv_labels_mask.to(updated_model.device),
                    "def_labels_mask": def_labels_mask.to(updated_model.device),
                    "prompt_mask": prompt_mask.to(updated_model.device),
                }
                adversary = projected_gradient_descent(
                    updated_model,
                    pgd_batch,
                    layer=adv_layer,
                    epsilon=epsilon,
                    n_pgd_steps=n_pgd_steps,
                    lr_adv=lr_adv,
                    loss_coefs=loss_coefs,
                )

                set_model_grads(adversary, False)
                set_params_grads(params, True)

                for step in range(n_def_per_adv_steps):

                    # Unlearning loss
                    with add_hooks(
                        module_forward_hooks=[
                            (
                                get_layer_module(updated_model, layer),
                                get_adversary_hook(adversary),
                            )
                        ],
                        module_forward_pre_hooks=[],
                    ):
                        updated_forget_activations = forward_with_cache(
                            updated_model,
                            unlearn_inputs,
                            module=updated_module,
                            no_grad=False,
                        ).to(updated_model.device)

                    unlearn_loss = torch.nn.functional.mse_loss(
                        updated_forget_activations, control_vec
                    )

                    # Retain loss
                    updated_retain_activations = forward_with_cache(
                        updated_model,
                        retain_inputs,
                        module=updated_module,
                        no_grad=False,
                    ).to(updated_model.device)
                    frozen_retain_activations = forward_with_cache(
                        frozen_model, retain_inputs, module=frozen_module, no_grad=True
                    ).to(updated_model.device)

                    retain_loss = torch.nn.functional.mse_loss(
                        updated_retain_activations, frozen_retain_activations
                    )
                    retain_loss *= alpha

                    # Update model
                    loss = unlearn_loss + retain_loss
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                else:
                    # Unlearning loss
                    updated_forget_activations = forward_with_cache(
                        updated_model,
                        unlearn_inputs,
                        module=updated_module,
                        no_grad=False,
                    ).to(updated_model.device)
                    unlearn_loss = torch.nn.functional.mse_loss(
                        updated_forget_activations, control_vec
                    )

                    # Retain loss
                    updated_retain_activations = forward_with_cache(
                        updated_model,
                        retain_inputs,
                        module=updated_module,
                        no_grad=False,
                    ).to(updated_model.device)
                    frozen_retain_activations = forward_with_cache(
                        frozen_model, retain_inputs, module=frozen_module, no_grad=True
                    ).to(updated_model.device)

                    retain_loss = torch.nn.functional.mse_loss(
                        updated_retain_activations, frozen_retain_activations
                    )
                    retain_loss *= alpha

                    # Update model
                    loss = unlearn_loss + retain_loss
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            if idx % 10 == 0:
                print(
                    f"Epoch: {epoch}, Batch: {idx}, Unlearn Loss: {unlearn_loss.item()}, Retain Loss: {retain_loss.item()}, Total Loss: {loss.item()}"
                )
                acc = eval_dataset(
                    updated_model, tokenizer, train_dataset, batch_size=batch_size
                )
                print(f"Validation Accuracy: {acc}")

    tokenizer.truncation_side = truncation_side

    return updated_model

In [None]:
updated_model = model
frozen_model = AutoModelForCausalLM.from_pretrained(
    model_id, revision="main", torch_dtype=torch.bfloat16
).to(device)


updated_model = run_rmu(
    updated_model,
    frozen_model,  # function to make another version of the model for reference embeddings
    tokenizer,  # tokenizer
    forget_data_list,  # forget data
    retain_data_list,  # retain data
    use_lat=True,
    n_def_per_adv_steps=4,
    n_pgd_steps=16,
    max_num_batches=200,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
  unlearn_loss = torch.nn.functional.mse_loss(
  unlearn_loss = torch.nn.functional.mse_loss(


Epoch: 0, Batch: 0, Unlearn Loss: 0.0849609375, Retain Loss: 0.173828125, Total Loss: 0.2578125


  0%|          | 1/200 [00:23<1:16:52, 23.18s/it]

Validation Accuracy: 0.6054852320675106


  unlearn_loss = torch.nn.functional.mse_loss(
  unlearn_loss = torch.nn.functional.mse_loss(
  5%|▌         | 10/200 [02:22<42:33, 13.44s/it] 

Epoch: 0, Batch: 10, Unlearn Loss: 0.06103515625, Retain Loss: 0.004302978515625, Total Loss: 0.0654296875


  6%|▌         | 11/200 [02:50<55:50, 17.73s/it]

Validation Accuracy: 0.5970464135021097


 10%|█         | 20/200 [04:50<40:36, 13.53s/it]

Epoch: 0, Batch: 20, Unlearn Loss: 0.05859375, Retain Loss: 0.0035552978515625, Total Loss: 0.062255859375


 10%|█         | 21/200 [05:26<1:00:50, 20.40s/it]

Validation Accuracy: 0.6033755274261603


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
 10%|█         | 21/200 [05:27<46:28, 15.58s/it]  


OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU 0 has a total capacity of 79.11 GiB of which 10.75 MiB is free. Including non-PyTorch memory, this process has 35.60 GiB memory in use. Process 3510749 has 43.47 GiB memory in use. Of the allocated memory 34.59 GiB is allocated by PyTorch, and 378.52 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

: 