In [2]:
%load_ext autoreload
%autoreload 2

import sys
import os
import dotenv
from pathlib import Path

env_file = "../.env"

if os.path.exists(env_file):
    dotenv.load_dotenv(env_file, verbose=True)
    print("Loaded environment variables from .env file.")

cwd = os.getcwd()
# for some reason appending to PATH you need it to be string
sys.path.append(str(Path(cwd).parent / "src"))
hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

Loaded environment variables from .env file.


In [3]:
from research_tools.gpu import get_gpus_available

os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in get_gpus_available()])

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert device == torch.device("cuda")

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
updated_model = AutoModelForCausalLM.from_pretrained(
    model_id, revision="main", torch_dtype=torch.bfloat16
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
from datasets import load_dataset
import json

unlearning_task = "cyber"

forget_corpus = (
    "cyber-forget-corpus" if unlearning_task == "cyber" else "bio-forget-corpus"
)
retain_corpus = "wikitext"


def get_unlearning_rmu_dataset(name, min_len=0, batch_size=4):
    data = []
    if name == "wikitext":
        raw_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        for x in raw_data:
            if len(x["text"]) > min_len:
                data.append(str(x["text"]))
    elif name == "cyber-forget-corpus":
        raw_data = load_dataset("cais/wmdp-corpora", name=name, split="train")
        for x in raw_data:
            if len(x["text"]) > min_len:
                data.append(str(x["text"]))
    else:  # wmdp bio must be pre-downloaded
        for line in open(f"data/{name}.jsonl", "r"):
            raw_text = json.loads(line)["text"]
            if len(raw_text) > min_len:
                data.append(str(raw_text))
    data = [data[i : i + batch_size] for i in range(0, len(data), batch_size)]
    return data


forget_data_list = get_unlearning_rmu_dataset(forget_corpus)
retain_data_list = get_unlearning_rmu_dataset(retain_corpus)

In [None]:
from unlearn_order.algos.lat.model import LATModel


lat_model = LATModel(
    model=updated_model, tokenizer=tokenizer, layer=12, attack_completions=False
)

In [None]:
from tqdm import tqdm
from unlearn_order.algos.lat.hooks import add_hooks
from unlearn_order.algos.lat.model import get_adversary_hook, projected_gradient_descent
from unlearn_order.algos.lat.utils import get_layer_module, set_model_grads
from typing import List, Dict
from transformers import LlamaForCausalLM, LlamaTokenizer


def get_layers_params(
    model: LlamaForCausalLM, layers: List[int], target_modules: List[str]
):
    params = []
    for layer in layers:
        for name, module in model.model.layers[layer].named_modules():
            if any([target in name for target in target_modules]):
                for param in module.parameters():
                    params.append(param)
    return params


def forward_with_cache(
    model: LlamaForCausalLM,
    inputs: Dict[str, torch.Tensor],
    module: torch.nn.Module,
    no_grad: bool = True,
):
    # define a tensor with the size of our cached activations
    cache = []

    def hook(module, input, output):
        if isinstance(output, tuple):
            cache.append(output[0])
        else:
            cache.append(output)
        return None

    hook_handle = module.register_forward_hook(hook)

    if no_grad:
        with torch.no_grad():
            _ = model(**inputs)
    else:
        _ = model(**inputs)

    hook_handle.remove()

    return cache[0]


# you're optimizing layer_ids
# there is a layer where you're examining the activations at and trying to make garbage
# then you only train on the laeyrs before that, this is just shoving garbage in
# not real unlearning?
# let us also latently perturb at the same layer


# remember params
def run_rmu(
    updated_model: LlamaForCausalLM,
    frozen_model: LlamaForCausalLM,
    tokenizer: LlamaTokenizer,
    forget_data_list,
    retain_data_list,
    target_modules: List[str] = ["down_proj"],
    alpha: float = 1200.0,
    layer: int = 8,  # layers to do RMU in
    lr_def: float = 5.0e-5,
    max_num_batches=200,
    steering_coef: float = 6.5,
    use_lat: bool = True,
    epsilon: float = 2,
    lr_adv: float = 5e-2,
    n_pgd_steps: int = 16,
    loss_coefs: Dict[str, float] = {"towards": 1, "away": 1},
    num_epochs: int = 1,
):
    adv_layer = layer - 1
    def_layers = [layer, layer - 1, layer - 2]

    updated_model.train()

    params = get_layers_params(updated_model, def_layers, target_modules)
    optimizer = torch.optim.AdamW(params, lr=lr_def)

    updated_module = get_layer_module(updated_model, layer)
    frozen_module = get_layer_module(frozen_model, layer)

    control_vectors_list = []
    for i in range(len(forget_data_list)):
        random_vector = torch.rand(
            1,
            1,
            updated_model.config.hidden_size,
            dtype=updated_model.dtype,
            device=updated_model.device,
        )
        control_vec = random_vector / torch.norm(random_vector) * steering_coef
        control_vectors_list.append(control_vec)

    num_batches = max_num_batches
    truncation_side = tokenizer.truncation_side
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.truncation_side = "right"

    for epoch in range(num_epochs):
        for idx in tqdm(range(num_batches)):
            control_vec = control_vectors_list[idx]
            unlearn_batch = forget_data_list[idx]
            retain_batch = retain_data_list[idx]

            max_length = 512 if idx == 0 else 768

            unlearn_inputs = tokenizer(
                unlearn_batch,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_length,
            ).to(updated_model.device)

            retain_inputs = tokenizer(
                retain_batch,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_length,
            ).to(updated_model.device)

            if use_lat:
                adv_labels_mask = torch.zeros_like(
                    unlearn_inputs["input_ids"], dtype=bool
                )
                def_labels_mask = torch.zeros_like(
                    retain_inputs["input_ids"], dtype=bool
                )

                for b, example in enumerate(retain_batch):
                    len_example = len(tokenizer(example)["input_ids"])
                    def_labels_mask[b, :len_example] = True
                for b, example in enumerate(unlearn_batch):
                    len_example = len(tokenizer(example)["input_ids"])
                    adv_labels_mask[b, :len_example] = True

                # prompt_mask = torch.zeros(len(unlearn_batch), pad_length + 1, dtype=torch.bool)
                pgd_batch = {
                    "def_tokens": retain_inputs["input_ids"].to(updated_model.device),
                    "adv_tokens": unlearn_inputs["input_ids"].to(updated_model.device),
                    "adv_labels_mask": adv_labels_mask.to(updated_model.device),
                    "def_labels_mask": def_labels_mask.to(updated_model.device),
                }
                adversary = projected_gradient_descent(
                    updated_model,
                    pgd_batch,
                    layer=layer,
                    epsilon=epsilon,
                    n_pgd_steps=n_pgd_steps,
                    lr_adv=lr_adv,
                    loss_coefs=loss_coefs,
                )

                set_model_grads(adversary, False)
                set_model_grads(updated_model, True)

                # Unlearning loss
                with add_hooks(
                    module_forward_hooks=[
                        (
                            get_layer_module(updated_model, layer),
                            get_adversary_hook(adversary),
                        )
                    ],
                    module_forward_pre_hooks=[],
                ):
                    updated_forget_activations = forward_with_cache(
                        updated_model,
                        unlearn_inputs,
                        module=updated_module,
                        no_grad=False,
                    ).to(updated_model.device)
                    unlearn_loss = torch.nn.functional.mse_loss(
                        updated_forget_activations, control_vec
                    )
            else:
                # Unlearning loss
                updated_forget_activations = forward_with_cache(
                    updated_model, unlearn_inputs, module=updated_module, no_grad=False
                ).to(updated_model.device)
                unlearn_loss = torch.nn.functional.mse_loss(
                    updated_forget_activations, control_vec
                )

            # Retain loss
            updated_retain_activations = forward_with_cache(
                updated_model, retain_inputs, module=updated_module, no_grad=False
            ).to(updated_model.device)
            frozen_retain_activations = forward_with_cache(
                frozen_model, retain_inputs, module=frozen_module, no_grad=True
            ).to(updated_model.device)

            retain_loss = torch.nn.functional.mse_loss(
                updated_retain_activations, frozen_retain_activations
            )
            retain_loss *= alpha

            # Update model
            loss = unlearn_loss + retain_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if idx % 10 == 0:
                print(
                    f"Epoch: {epoch}, Batch: {idx}, Unlearn Loss: {unlearn_loss.item()}, Retain Loss: {retain_loss.item()}, Total Loss: {loss.item()}"
                )

    tokenizer.truncation_side = truncation_side

    return updated_model

In [None]:
frozen_model = AutoModelForCausalLM.from_pretrained(
    model_id, revision="main", torch_dtype=torch.bfloat16
).to(device)
updated_model = run_rmu(
    updated_model,
    frozen_model,  # function to make another version of the model for reference embeddings
    tokenizer,  # tokenizer
    forget_data_list,  # forget data
    retain_data_list,  # retain data
    n_pgd_steps=16,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
  unlearn_loss = torch.nn.functional.mse_loss(
  0%|          | 1/200 [00:08<27:37,  8.33s/it]

Epoch: 0, Batch: 0, Unlearn Loss: 0.08837890625, Retain Loss: 0.0, Total Loss: 0.08837890625


  unlearn_loss = torch.nn.functional.mse_loss(
  6%|▌         | 11/200 [02:02<35:55, 11.41s/it]

Epoch: 0, Batch: 10, Unlearn Loss: 0.06298828125, Retain Loss: 0.232421875, Total Loss: 0.294921875


 10%|█         | 21/200 [03:56<34:11, 11.46s/it]

Epoch: 0, Batch: 20, Unlearn Loss: 0.0634765625, Retain Loss: 0.1796875, Total Loss: 0.2431640625


 14%|█▍        | 28/200 [05:15<32:30, 11.34s/it]