In [None]:
from wmdp import get_params, forward_with_cache, get_data
from datasets import load_dataset
from dotenv import load_dotenv
import os
import torch
from typing import Optional, Union, Optional
from latent_at.lat_methods import (
    projected_gradient_descent
)
from omegaconf import OmegaConf
import yaml
from latent_at.laa import (
    clear_hooks,
)
import datetime
import torch
import itertools
from torch.optim import AdamW
import tqdm as tqdm
import wandb
from wmdp_utils import evaluate_phillip, eval_and_log, log, load_model

# Populate os.environ from a .env file (if one is found), then read the Hugging Face tokens.
load_dotenv()
hf_api_key, hf_write_key = os.getenv("HF_KEY"), os.getenv("HF_WRITE_KEY")

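### Download the datasets
Run the cell below to fetch any WMDP corpora that are missing from `wmdp/data/` (relative to the working directory). Files that already exist are skipped, so delete a file to force it to be re-downloaded. The bio forget corpus is downloaded from a separate Hugging Face repo using the `HF_KEY` token.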

In [None]:
# Download any missing corpora into wmdp/data/ as JSONL; files already on disk are skipped.
# to_json does not create the target directory, so make sure it exists first.
os.makedirs("wmdp/data", exist_ok=True)

dataset_names = ["bio-retain-corpus", "cyber-forget-corpus", "cyber-retain-corpus"] # Everything on cais/wmdp-corpora
for name in dataset_names:
    corpus_path = f"wmdp/data/{name}.jsonl"
    if not os.path.exists(corpus_path):
        dataset = load_dataset('cais/wmdp-corpora', name)
        dataset["train"].to_json(corpus_path, batch_size=128)

# The bio forget corpus lives in a separate repo and is fetched with the HF_KEY token.
save_path = "wmdp/data/bio-forget-corpus.jsonl"
if not os.path.exists(save_path):
    dataset = load_dataset("CindyXWu/wmdp-private", token=hf_api_key)
    dataset["train"].to_json(save_path, batch_size=128)

In [None]:
def load_config(file_path: str):
    """Read a YAML file and wrap its contents in an OmegaConf config.

    Args:
        file_path: path to the YAML config file.

    Returns:
        The result of ``OmegaConf.create`` on the parsed YAML (a ``DictConfig``
        when the top level is a mapping). This is not a plain ``dict``; it
        supports attribute access such as ``config.rmu_config``.
    """
    # Binary mode lets PyYAML detect the file encoding itself (UTF-8 unless a
    # BOM says otherwise) instead of relying on the platform's locale encoding.
    with open(file_path, "rb") as stream:
        raw_config = yaml.safe_load(stream)
    return OmegaConf.create(raw_config)


config = load_config("wmdp_lat_main_config.yaml")
# RMU hyperparameters, read at module level by the data-loading cell.
args = config.rmu_config

In [None]:
# NOTE: USE_CUT CURRENTLY BAKED IN
def train_attack(
    model,
    pgd_config,
    batch: dict[str, torch.Tensor],
    do_grad_step: bool,
    pca_kwargs: Optional[dict] = None,
) -> tuple:
    """Run a latent-space PGD attack against ``model`` on a single batch.

    Thin wrapper over ``projected_gradient_descent`` that pulls the attack
    hyperparameters out of ``pgd_config``.

    Args:
        model: model under attack; ``"model.layers"`` is passed as the layer
            container path and ``model.device`` as the device.
        pgd_config: config node providing ``pgd_layers``, ``epsilon``,
            ``inner_learning_rate``, ``pgd_iterations_per_step``,
            ``adv_loss_coefs``, ``towards_loss_type`` and ``away_loss_type``.
        batch: attack batch, forwarded unchanged.
        do_grad_step: forwarded as ``log_loss``.
        pca_kwargs: optional PCA settings, forwarded unchanged.

    Returns:
        Whatever ``projected_gradient_descent`` returns. Callers in this
        notebook unpack it as ``(losses, hooks, _)``; presumably a 3-tuple,
        TODO confirm against ``latent_at.lat_methods``.
    """
    return projected_gradient_descent(
        batch=batch,
        model=model,
        model_layers_module="model.layers",
        layer_list=pgd_config.pgd_layers,
        epsilon=pgd_config.epsilon,
        learning_rate=pgd_config.inner_learning_rate,
        pgd_iterations=pgd_config.pgd_iterations_per_step,
        loss_coefs=pgd_config.adv_loss_coefs,
        log_loss=do_grad_step,
        device=model.device,
        towards_loss_type=pgd_config.towards_loss_type,
        away_loss_type=pgd_config.away_loss_type,
        pca_kwargs=pca_kwargs,
    )

In [None]:
def run_rmu(
    forget_data_list,
    retain_data_list,
    config,
):
    """Unlearn with RMU, optionally adversarially hardened with latent PGD attacks.

    Steps visit topics round-robin: step ``idx`` uses batch ``idx // n_topics``
    of topic ``idx % n_topics`` from both the forget and retain lists. The
    updated model is trained so that the activations of the module selected by
    ``args.module_str`` / ``args.layer_id`` on forget data match a fixed random
    per-topic control vector (unlearn loss). Its activations on retain data are
    kept close to those of a frozen copy (retain loss, scaled by ``args.alpha``).
    When ``args.use_pgd`` is set, a PGD adversary is fit on each batch first.
    Its hooks are enabled only for the forget forward pass.

    A checkpoint is written to ``wmdp_rmulat_models/<run_name>`` every
    ``config.eval_steps`` steps (step 0 included) and again at the end. Losses
    are logged via ``log`` (wandb run opened here).

    Args:
        forget_data_list: per-topic lists of forget batches (passed to the
            tokenizer, so presumably lists of strings; TODO confirm).
        retain_data_list: per-topic lists of retain batches, aligned with
            ``forget_data_list`` by topic index.
        config: full OmegaConf run config (``rmu_config``, ``pgd_config``,
            ``def_config``, ``wandb_config``, ``eval_steps``, ...).

    Note:
        At the moment, support for adversary attacks as steering vectors is
        limited to the first adversary in the returned list. This is the first
        one specified in the config list, so make sure it's the right one for
        the right layer!
    """
    args = config.rmu_config
    run_name = f"pgd-{args.use_pgd}_rmutrainablelayers-{args.layer_ids}_rmulayer-{args.layer_id}_pgdlayers-{config.pgd_config.pgd_layers}_eps-{config.pgd_config.epsilon}_defsteps-{config.def_config.model_iterations_per_step}_pdsteps-{config.pgd_config.pgd_iterations_per_step}_alpha-{args.alpha}_{config.wandb_config.additional_run_name if config.wandb_config.additional_run_name else ''}_{datetime.datetime.now().strftime('%d-%m_%H-%M-%S')}"
    logger_params = {
        "name": run_name,
        "project": "rmu_lat",
        "settings": wandb.Settings(start_method="thread"),
        "config": OmegaConf.to_container(config),
        "mode": "online"
    }
    wandb.init(**logger_params, entity="quirky_lats_at_mats")
    print("====rmu Config====")
    print("\n".join(f"{k}={v}" for k, v in args.items()))
    print("=====")

    updated_model, tokenizer = load_model(config, hf_api_key)
    frozen_model, tokenizer = load_model(config, hf_api_key)
    updated_model = updated_model.train()
    params = get_params(updated_model, args.layer_ids, args.param_ids)
    optimizer = AdamW(params, lr=args.lr)
    # NOTE(review): eval() of a config string. Acceptable only because the config
    # file is trusted. It must run in this scope so that the names
    # "frozen_model"/"updated_model" resolve to the locals above.
    frozen_module = eval(
        args.module_str.format(model_name="frozen_model", layer_id=args.layer_id)
    )
    updated_module = eval(
        args.module_str.format(model_name="updated_model", layer_id=args.layer_id)
    )

    def disable_model_gradients() -> None:
        """Freeze every model parameter (called before the PGD attack)."""
        for param in updated_model.parameters():
            param.requires_grad_(False)

    def enable_model_gradients() -> None:
        """Re-enable grads on the RMU-trainable layers after the attack.

        Only layers in ``args.layer_ids`` hold optimizer parameters. Enabling
        grads elsewhere would just accumulate ``.grad`` tensors that
        ``optimizer.zero_grad()`` never clears.
        """
        layers = updated_model.get_submodule("model.layers")
        for layer_id in args.layer_ids:
            layers[layer_id].requires_grad_(True)

    def run_pgd_attack(unlearn_inputs, retain_inputs, idx):
        """Fit a PGD adversary on this batch, log its losses, and return its hooks (disabled)."""
        # attention_mask marks real (non-pad) tokens, so it is the correct label
        # mask whichever side the tokenizer pads on. It also avoids re-tokenizing
        # every example.
        pgd_batch = {
            "def_tokens": retain_inputs["input_ids"].to(updated_model.device),
            "adv_tokens": unlearn_inputs["input_ids"].to(updated_model.device),
            "adv_labels_mask": unlearn_inputs["attention_mask"].bool().to(updated_model.device),
            "def_labels_mask": retain_inputs["attention_mask"].bool().to(updated_model.device),
        }
        disable_model_gradients()
        losses, hooks, _ = train_attack(
            updated_model,
            config.pgd_config,
            batch=pgd_batch,
            do_grad_step=True,
        )
        log(losses, idx=idx)
        enable_model_gradients()
        for hook in hooks:
            hook.enabled = False
        return hooks

    def rmu_losses(unlearn_inputs, retain_inputs, control_vec, hooks=()):
        """Return ``(unlearn_loss, retain_loss)``; ``hooks`` are on only for the forget pass."""
        # Unlearning loss: pull forget-set activations towards the control vector.
        for hook in hooks:
            hook.enabled = True
        updated_forget_activations = forward_with_cache(
            updated_model, unlearn_inputs, module=updated_module, no_grad=False
        ).to(updated_model.device)
        unlearn_loss = torch.nn.functional.mse_loss(
            updated_forget_activations, control_vec
        )

        # Retain loss: keep retain-set activations close to the frozen model's, with attacks off.
        for hook in hooks:
            hook.enabled = False
        updated_retain_activations = forward_with_cache(
            updated_model, retain_inputs, module=updated_module, no_grad=False
        ).to(updated_model.device)
        frozen_retain_activations = forward_with_cache(
            frozen_model, retain_inputs, module=frozen_module, no_grad=True
        ).to(updated_model.device)
        retain_loss = torch.nn.functional.mse_loss(
            updated_retain_activations, frozen_retain_activations
        )
        retain_loss *= args.alpha
        return unlearn_loss, retain_loss

    # One fixed random direction per topic, scaled to norm args.steering_coef.
    control_vectors_list = []
    for _ in range(len(forget_data_list)):
        random_vector = torch.rand(1, 1, updated_model.config.hidden_size, dtype=updated_model.dtype, device=updated_model.device)
        control_vectors_list.append(random_vector / torch.norm(random_vector) * args.steering_coef)

    # Step idx reads batch idx // n_topics of every topic list, so cap the step
    # count at what the shortest list can serve; otherwise indexing raises IndexError.
    n_topics = len(forget_data_list)
    shortest = min(
        (len(batches) for batches in [*forget_data_list, *retain_data_list[:n_topics]]),
        default=0,
    )
    num_batches = min(args.max_num_batches, n_topics * shortest)
    if num_batches < args.max_num_batches:
        print(f"max_num_batches={args.max_num_batches} exceeds available data; running {num_batches} steps")

    save_path = f"wmdp_rmulat_models/{run_name}"
    truncation_side = tokenizer.truncation_side
    tokenizer.truncation_side = "right"
    try:
        for epoch in range(1):
            print(f"======= Epoch {epoch} =======")
            with tqdm.tqdm(total=num_batches) as pbar:
                for idx in range(num_batches):
                    clear_hooks(updated_model)

                    if idx % config.eval_steps == 0:
                        updated_model.save_pretrained(save_path)
                        tokenizer.save_pretrained(save_path)
                        with torch.inference_mode():
                            eval_and_log(config, updated_model, idx)

                    topic_idx = idx % n_topics
                    batch_idx = idx // n_topics
                    control_vec = control_vectors_list[topic_idx]
                    unlearn_batch = forget_data_list[topic_idx][batch_idx]
                    retain_batch = retain_data_list[topic_idx][batch_idx]

                    # Per-topic truncation limits: the first topic uses 512 tokens, the
                    # others 768. Retain data always uses 512.
                    max_length = 512 if topic_idx == 0 else 768
                    unlearn_inputs = tokenizer(
                        unlearn_batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length
                    ).to(updated_model.device)
                    retain_inputs = tokenizer(
                        retain_batch, return_tensors="pt", padding=True, truncation=True, max_length=512
                    ).to(updated_model.device)

                    if args.use_pgd:
                        hooks = run_pgd_attack(unlearn_inputs, retain_inputs, idx)
                        defense_steps = config.def_config.model_iterations_per_step
                    else:
                        hooks = []
                        defense_steps = 1

                    for _ in range(defense_steps):
                        unlearn_loss, retain_loss = rmu_losses(
                            unlearn_inputs, retain_inputs, control_vec, hooks
                        )
                        loss = unlearn_loss + retain_loss
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    loss_dict = {"total_loss": loss.item(), "unlearn_loss": unlearn_loss.item(), "retain_loss": retain_loss.item()}
                    log(loss_dict, idx=idx)

                    pbar.update(1)
    finally:
        # Restore the caller's tokenizer setting even if training fails.
        tokenizer.truncation_side = truncation_side

    # Save model
    clear_hooks(updated_model)
    updated_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    print(f"Saved model to {save_path}")
    # NOTE(review): no harness evaluation is run inside this function.
    print("Evaluating saved model with harness")

In [None]:
# Build per-topic forget/retain batches from the corpora listed in rmu_config, then train.
forget_data_list, retain_data_list = get_data(
    args.forget_corpora, args.retain_corpora, args.min_len, args.max_len, args.batch_size
)

run_rmu(forget_data_list, retain_data_list, config)