In [None]:
import sys
import os
import torch
from typing import Optional, Union, Optional

from omegaconf import DictConfig, OmegaConf
import yaml
from datasets import load_dataset
from dotenv import load_dotenv
import datetime
import torch
from torch.optim import AdamW
import tqdm as tqdm
import wandb

# Make the project root (the parent of the notebook's working directory)
# importable, so the local packages imported below resolve.
# Note: os.path.abspath("__file__") resolves the *string* "__file__" against the
# current working directory, so its dirname is simply the working directory.
notebook_dir = os.getcwd()
project_dir = os.path.abspath(os.path.join(notebook_dir, os.pardir))
if project_dir not in sys.path:
    sys.path.insert(0, project_dir)

from wmdp_utils import evaluate_during_training, eval_and_log, log, load_model
from latent_at.wmdp.rmu.utils import get_params, forward_with_cache, get_data
from latent_at.laa import clear_hooks
from latent_at.lat_methods import projected_gradient_descent

# Populate os.environ via python-dotenv, then read the Hugging Face tokens.
# Each key is None when the corresponding variable is not set.
load_dotenv()
hf_api_key, hf_write_key = (os.getenv(var_name) for var_name in ("HF_KEY", "HF_WRITE_KEY"))

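### (Re-)downloading the datasets
Run the code cell below to download any WMDP corpora that are missing locally. Files are saved as JSONL in the `wmdp_data/` folder, relative to the notebook's working directory; corpora that already exist there are skipped.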

In [None]:
# Corpora fetched from the cais/wmdp-corpora Hub dataset. bio-forget-corpus is
# handled separately (see the TODO below).
dataset_names = ["bio-retain-corpus", "cyber-forget-corpus", "cyber-retain-corpus"]
# Make sure the output folder exists before writing into it (fresh checkout).
os.makedirs("wmdp_data", exist_ok=True)
for name in dataset_names:
    # Skip corpora that were already downloaded.
    if not os.path.exists(f"wmdp_data/{name}.jsonl"):
        dataset = load_dataset('cais/wmdp-corpora', name)
        dataset["train"].to_json(f'wmdp_data/{name}.jsonl', batch_size=128)

# TODO: Replace with bio loading code here from local file
# save_path = f"wmdp_data/bio-forget-corpus.jsonl"
# if not os.path.exists(save_path):
#     dataset = load_dataset("CindyXWu/wmdp-private", token=hf_api_key)
#     dataset["train"].to_json(save_path, batch_size=128)

In [None]:
def load_config(file_path: str) -> DictConfig:
    """Load a YAML file and wrap its contents in an OmegaConf config.

    Args:
        file_path: Path to the YAML config file.

    Returns:
        The OmegaConf container built from the parsed YAML. For a top-level YAML
        mapping this is a ``DictConfig``, which supports the attribute access used
        below (e.g. ``config.rmu_config``).
    """
    # Read as UTF-8 rather than relying on the platform's default encoding.
    with open(file_path, "r", encoding="utf-8") as stream:
        raw_config = yaml.safe_load(stream)
    return OmegaConf.create(raw_config)


config: DictConfig = load_config("wmdp_rmu_lat_config.yaml")
args = config.rmu_config

In [None]:
def train_attack(
    model,
    pgd_config,
    batch: dict[str, torch.Tensor],
    do_grad_step: bool,
    pca_kwargs: Optional[dict] = None,
    model_layers_module: str = "model.layers",
):
    """Run a latent-space PGD attack against ``model`` on a single batch.

    Args:
        model: Model to attack; its ``.device`` is forwarded to the attack.
        pgd_config: Config node supplying ``pgd_layers``, ``epsilon``,
            ``inner_learning_rate``, ``pgd_iterations_per_step`` and ``adv_loss_coefs``.
        batch: Token and mask tensors for the attack (in this notebook, built by
            ``prepare_pgd_batch``).
        do_grad_step: Forwarded to the attack as ``log_loss``.
        pca_kwargs: Optional extra arguments forwarded to the attack.
        model_layers_module: Dotted path of the decoder-layer container inside
            ``model``; defaults to ``"model.layers"``.

    Returns:
        Whatever ``projected_gradient_descent`` returns. ``run_rmu`` unpacks it as
        ``(losses, hooks, _)``.
    """
    return projected_gradient_descent(
        batch=batch,
        model=model,
        model_layers_module=model_layers_module,
        layer_list=pgd_config.pgd_layers,
        epsilon=pgd_config.epsilon,
        learning_rate=pgd_config.inner_learning_rate,
        pgd_iterations=pgd_config.pgd_iterations_per_step,
        loss_coefs=pgd_config.adv_loss_coefs,
        log_loss=do_grad_step,
        device=model.device,
        pca_kwargs=pca_kwargs,
    )

### Helper functions

In [None]:
def initialize_models(config):
    """Load two independent copies of the model: one to train, one as a frozen reference.

    ``config`` is accepted for interface compatibility but is not read here.

    Returns:
        ``(updated_model, tokenizer, frozen_model)``, with ``updated_model`` switched
        to training mode.
    """
    updated_model, tokenizer = load_model(hf_api_key)
    frozen_model, _unused_tokenizer = load_model(hf_api_key)
    return updated_model.train(), tokenizer, frozen_model


def create_control_vectors(forget_data_list, updated_model, steering_coef):
    """Make an RMU control vector separately for each dataset in ``forget_data_list``.

    Each vector has shape ``(1, 1, hidden_size)``, uses the model's dtype and device,
    is drawn with ``torch.rand`` and is rescaled to L2 norm ``steering_coef``.
    """
    hidden_size = updated_model.config.hidden_size

    def _scaled_random_direction():
        direction = torch.rand(1, 1, hidden_size, dtype=updated_model.dtype, device=updated_model.device)
        return steering_coef * (direction / torch.norm(direction))

    return [_scaled_random_direction() for _ in range(len(forget_data_list))]


def prepare_inputs(tokenizer, unlearn_batch, retain_batch, topic_idx, device):
    """Tokenize one forget batch and one retain batch and move both to ``device``.

    Forget text is truncated to 512 tokens for topic 0 and to 768 tokens for any
    other topic. Retain text is always truncated to 512 tokens.

    Returns:
        ``(unlearn_inputs, retain_inputs)``, the tokenizer outputs for each batch.
    """
    shared_kwargs = dict(return_tensors="pt", padding=True, truncation=True)
    forget_max_length = 768 if topic_idx != 0 else 512
    unlearn_inputs = tokenizer(unlearn_batch, max_length=forget_max_length, **shared_kwargs).to(device)
    retain_inputs = tokenizer(retain_batch, max_length=512, **shared_kwargs).to(device)
    return unlearn_inputs, retain_inputs


def _real_token_mask(inputs, tokenizer, texts):
    """Return a bool mask over ``inputs["input_ids"]`` that is True on non-padding tokens."""
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        # The tokenizer's attention mask marks real tokens regardless of padding side.
        return attention_mask.to(dtype=torch.bool)
    # Fallback when no attention mask is present: mark the first
    # len(tokenized text) positions of each row (only correct for right padding).
    mask = torch.zeros_like(inputs["input_ids"], dtype=torch.bool)
    for row, text in enumerate(texts):
        mask[row, :len(tokenizer(text)["input_ids"])] = True
    return mask


def prepare_pgd_batch(unlearn_inputs, retain_inputs, updated_model, tokenizer, retain_batch, unlearn_batch):
    """Build the batch dict consumed by ``train_attack``.

    The label masks come from each encoding's ``attention_mask``, so they stay
    correct for left-padded tokenizers and no re-tokenization is needed.
    ``tokenizer``, ``retain_batch`` and ``unlearn_batch`` are only used as a fallback
    when an encoding has no attention mask.

    Returns:
        A dict with ``def_tokens``/``def_labels_mask`` (retain side) and
        ``adv_tokens``/``adv_labels_mask`` (forget side), all on ``updated_model.device``.
    """
    device = updated_model.device
    adv_labels_mask = _real_token_mask(unlearn_inputs, tokenizer, unlearn_batch)
    def_labels_mask = _real_token_mask(retain_inputs, tokenizer, retain_batch)
    return {
        "def_tokens": retain_inputs["input_ids"].to(device),
        "adv_tokens": unlearn_inputs["input_ids"].to(device),
        "adv_labels_mask": adv_labels_mask.to(device),
        "def_labels_mask": def_labels_mask.to(device),
    }


def disable_model_gradients(model):
    """Set ``requires_grad=False`` on every parameter of ``model`` (in place, returns None)."""
    model.requires_grad_(False)


def enable_model_gradients(model):
    """Re-enable gradients on the first ``config.num_hidden_layers`` entries of ``model.layers``.

    Parameters outside that container (e.g. embeddings or head) are left untouched.
    """
    layer_indices = range(model.config.num_hidden_layers)
    if not layer_indices:
        return
    decoder_layers = model.get_submodule("model.layers")
    for layer_idx in layer_indices:
        decoder_layers[layer_idx].requires_grad_(True)


def compute_forget_loss(updated_model, unlearn_inputs, control_vec, module):
    """RMU forget loss: MSE between ``module``'s activations on the forget batch and ``control_vec``.

    ``control_vec`` (shape ``(1, 1, hidden_size)`` when built by
    ``create_control_vectors``) is broadcast against the activations explicitly.
    ``mse_loss`` would broadcast identically on its own, but it emits a
    "target size different to the input size" UserWarning when it has to.
    """
    activations = forward_with_cache(updated_model, unlearn_inputs, module=module, no_grad=False).to(updated_model.device)
    activations, target = torch.broadcast_tensors(activations, control_vec)
    return torch.nn.functional.mse_loss(activations, target)


def compute_retain_loss(updated_model, retain_inputs, frozen_model, alpha, updated_module, frozen_module):
    """RMU retain loss: ``alpha`` times the MSE between the two models' activations on the retain batch.

    Activations are read at ``updated_module`` (with gradients) and at
    ``frozen_module`` (via ``no_grad=True``), and both are moved to
    ``updated_model.device``.
    """
    target_device = updated_model.device
    current_acts = forward_with_cache(updated_model, retain_inputs, module=updated_module, no_grad=False).to(target_device)
    reference_acts = forward_with_cache(frozen_model, retain_inputs, module=frozen_module, no_grad=True).to(target_device)
    mse = torch.nn.functional.mse_loss(current_acts, reference_acts)
    return mse * alpha

The RMU loss functions used below (`compute_rmu_forget_loss` and `compute_rmu_retain_loss`, imported from `latent_at.lat_helpers`) come from our codebase and are written to match the interface of our other loss functions; the codebase versions always have "rmu" in their names. Simplified versions of the same two losses, `compute_forget_loss` and `compute_retain_loss`, are defined in the helper functions above.

In [None]:
from latent_at.lat_helpers import compute_rmu_retain_loss, compute_rmu_forget_loss


def _set_hooks_enabled(hooks, enabled):
    """Switch every adversary hook in ``hooks`` on or off."""
    for hook in hooks:
        hook.enabled = enabled


def _available_rmu_steps(forget_data_list, retain_data_list):
    """Number of round-robin steps possible before any topic runs out of batches.

    Step ``idx`` reads batch ``idx // n_topics`` of topic ``idx % n_topics`` from both
    lists, so the shortest per-topic list bounds the run at ``n_topics * shortest`` steps.
    """
    n_topics = len(forget_data_list)
    if n_topics == 0:
        raise ValueError("forget_data_list must contain at least one topic")
    if len(retain_data_list) < n_topics:
        raise ValueError("retain_data_list needs one batch list per forget topic")
    shortest = min(len(batches) for batches in [*forget_data_list, *retain_data_list[:n_topics]])
    return n_topics * shortest


def _rmu_update_step(
    updated_model,
    frozen_model,
    optimizer,
    unlearn_inputs,
    retain_inputs,
    control_vec,
    updated_module,
    frozen_module,
    alpha,
    hooks=(),
):
    """Run one RMU optimizer step and return ``(loss, unlearn_loss, retain_loss)``.

    Adversary ``hooks`` (if any) are enabled only while the forget loss is computed,
    so PGD perturbations apply to the forget objective but not the retain objective.
    They are left disabled on return.
    """
    _set_hooks_enabled(hooks, True)
    unlearn_loss = compute_rmu_forget_loss(updated_model, unlearn_inputs, control_vec, updated_module)
    _set_hooks_enabled(hooks, False)
    retain_loss = compute_rmu_retain_loss(updated_model, retain_inputs, frozen_model, alpha, updated_module, frozen_module)

    loss = unlearn_loss + retain_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss, unlearn_loss, retain_loss


def run_rmu(
    forget_data_list,
    retain_data_list,
    config,
):
    """Train a model with RMU, optionally hardened by latent-space PGD attacks (LAT).

    At the moment, support for adversary attacks as steering vectors is limited to the
    first adversary in the returned list. This is the first one specified in the config
    list, so make sure it's the right one for the right layer!

    Args:
        forget_data_list: Per-topic lists of forget batches, indexed
            ``forget_data_list[topic_idx][batch_idx]``.
        retain_data_list: Per-topic lists of retain batches, indexed the same way.
        config: Full config; reads ``rmu_config``, ``pgd_config``, ``def_config`` and
            ``eval_steps``.

    Side effects: starts a wandb run, checkpoints the model and tokenizer every
    ``eval_steps`` steps, and saves the final model and tokenizer under
    ``wmdp_rmulat_models/<run_name>``.
    """
    args = config.rmu_config
    # Fail fast on malformed data before loading models or starting a wandb run.
    available_steps = _available_rmu_steps(forget_data_list, retain_data_list)

    run_name = f"rmutrainablelayers-{args.layer_ids}_rmulayer-{args.layer_id}_pgdlayers-{config.pgd_config.pgd_layers}_eps-{config.pgd_config.epsilon}_defsteps-{config.def_config.model_iterations_per_step}_pdsteps-{config.pgd_config.pgd_iterations_per_step}_alpha-{args.alpha}_{datetime.datetime.now().strftime('%d-%m_%H-%M-%S')}"
    logger_params = {
        "name": run_name,
        "project": "rmu_lat",
        "settings": wandb.Settings(start_method="thread"),
        "config": OmegaConf.to_container(config),
        "mode": "online"
    }
    wandb.init(**logger_params, entity="quirky_lats_at_mats")
    print("====rmu Config====")
    print("\n".join(f"{k}={v}" for k,v in config.rmu_config.items()))
    print("=====")

    updated_model, tokenizer, frozen_model = initialize_models(config)

    params = get_params(updated_model, args.layer_ids, args.param_ids)
    optimizer = AdamW(params, lr=args.lr)

    # NOTE: eval() executes code from the config file and resolves the local names
    # `frozen_model` / `updated_model`; only use trusted config files.
    frozen_module = eval(args.module_str.format(model_name="frozen_model", layer_id=args.layer_id))
    updated_module = eval(args.module_str.format(model_name="updated_model", layer_id=args.layer_id))

    control_vectors_list = create_control_vectors(forget_data_list, updated_model, args.steering_coef)

    # Cap the run so batch indexing can never go past the shortest topic list.
    num_batches = min(args.max_num_batches, available_steps)
    if num_batches < args.max_num_batches:
        print(f"Only {available_steps} batches available; training for {num_batches} steps instead of {args.max_num_batches}.")

    truncation_side = tokenizer.truncation_side
    tokenizer.truncation_side = "right"
    save_path = f"wmdp_rmulat_models/{run_name}"

    try:
        for epoch in range(1):
            print(f"======= Epoch {epoch} =======")
            with tqdm.tqdm(total=num_batches) as pbar:
                for idx in range(num_batches):
                    clear_hooks(updated_model)

                    if idx % config.eval_steps == 0:
                        # Checkpoint, then evaluate without tracking gradients.
                        updated_model.save_pretrained(save_path)
                        tokenizer.save_pretrained(save_path)
                        with torch.inference_mode():
                            eval_and_log(config, updated_model, idx)

                    # Round-robin over topics: consecutive steps alternate topics.
                    topic_idx, batch_idx = idx % len(forget_data_list), idx // len(forget_data_list)
                    control_vec = control_vectors_list[topic_idx]
                    unlearn_batch = forget_data_list[topic_idx][batch_idx]
                    retain_batch = retain_data_list[topic_idx][batch_idx]

                    unlearn_inputs, retain_inputs = prepare_inputs(tokenizer, unlearn_batch, retain_batch, topic_idx, updated_model.device)

                    if args.use_pgd:
                        pgd_batch = prepare_pgd_batch(unlearn_inputs, retain_inputs, updated_model, tokenizer, retain_batch, unlearn_batch)
                        # Freeze the model so the attack only optimizes the adversary.
                        disable_model_gradients(updated_model)
                        losses, hooks, _ = train_attack(
                            updated_model,
                            config.pgd_config,
                            batch=pgd_batch,
                            do_grad_step=True,
                        )
                        log(losses, idx=idx)
                        enable_model_gradients(updated_model)
                        _set_hooks_enabled(hooks, False)
                        defense_steps = config.def_config.model_iterations_per_step
                    else:  # Just code from CUT
                        hooks = ()
                        defense_steps = 1

                    step_losses = None
                    for _ in range(defense_steps):
                        step_losses = _rmu_update_step(
                            updated_model,
                            frozen_model,
                            optimizer,
                            unlearn_inputs,
                            retain_inputs,
                            control_vec,
                            updated_module,
                            frozen_module,
                            args.alpha,
                            hooks,
                        )

                    # With zero defense iterations there is nothing to log for this step.
                    if step_losses is not None:
                        loss, unlearn_loss, retain_loss = step_losses
                        loss_dict = {"total_loss": loss.item(), "unlearn_loss": unlearn_loss.item(), "retain_loss": retain_loss.item()}
                        log(loss_dict, idx=idx)

                    pbar.update(1)
    finally:
        # Restore tokenizer state and drop adversary hooks even if training fails.
        tokenizer.truncation_side = truncation_side
        clear_hooks(updated_model)

    # Save model
    updated_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    print(f"Saved model to {save_path}")
    print("Evaluating saved model with harness")
    print("Evaluating saved model with harness")

In [None]:
# Build the per-topic lists of forget/retain batches from the RMU config, then train.
forget_data_list, retain_data_list = get_data(
    args.forget_corpora, args.retain_corpora, args.min_len, args.max_len, args.batch_size
)
run_rmu(forget_data_list, retain_data_list, config)