In [None]:
import datetime
import os
import re
import sys
from itertools import cycle
from typing import Union, Optional

import tqdm as tqdm
import yaml
from datasets import load_dataset, Dataset
from dotenv import load_dotenv
from omegaconf import OmegaConf
from peft import AutoPeftModelForCausalLM
from peft import get_peft_model, LoraConfig
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

# Inside a notebook `__file__` is undefined, so the literal string "__file__" is
# resolved against the current working directory; its dirname is therefore the
# working directory (normally the folder holding this notebook).
notebook_dir = os.path.dirname(os.path.abspath("__file__"))
# The repository root sits one level up. Put it first on sys.path so the project
# packages imported below resolve, without adding a duplicate entry on re-runs.
project_dir = os.path.abspath(os.path.join(notebook_dir, os.pardir))
if project_dir not in sys.path:
    sys.path[:0] = [project_dir]

from latent_at.laa.attacks import *
from latent_at.lat_datasets import process_pretokenized_dataset, WMDPLATTargetedDataCollator, WMDPLATDataCollator
from latent_at.laa import add_hooks, clear_hooks
from latent_at.lat_methods import WMDPProjectedGradLAT
from latent_at.wmdp.rmu.utils import get_data
from tasks.general_capabilities.multiple_choice_tasks import WMDPTask, MMLUTask
from tasks.wmdp.WMDP_MCTask import WMDP_MCTask
from wmdp_utils import evaluate_during_training, eval_and_log, load_model

# Read variables from a local `.env` file into the process environment, then fetch
# the Hugging Face token (None when HUGGINGFACE_API_KEY is not set).
load_dotenv()
hf_access_token = os.environ.get("HUGGINGFACE_API_KEY")

## Load config file as OmegaConf object

In [None]:
class _ConfigYamlLoader(yaml.SafeLoader):
    """`yaml.SafeLoader` that also resolves exponent-style numbers (e.g. `1e-5`) to float."""


# PyYAML implements YAML 1.1. Its float pattern requires a '.' in the mantissa and a
# sign in the exponent, so values like `1e-5` or `1.0e5` otherwise load as *strings*
# (a common trap for learning rates). This extra pattern covers those forms. Values
# the default resolvers already handle (ints, `1.5`, `1.0e-5`) are matched by the
# earlier-registered resolvers and are unaffected.
_ConfigYamlLoader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(r"^[-+]?[0-9][0-9_]*(?:\.[0-9_]*)?[eE][-+]?[0-9]+$"),
    list("-+0123456789"),
)


def load_config(file_path: str) -> dict:
    """
    Read a YAML config file into a plain dict.

    Parsing matches `yaml.safe_load`, with two exceptions:
    scientific-notation numbers such as `1e-5` become floats rather than strings,
    and an empty file yields `{}` instead of None.

    Args:
        file_path: path to the YAML file.

    Returns:
        The parsed top-level mapping.
    """
    with open(file_path, "r") as stream:
        parsed = yaml.load(stream, Loader=_ConfigYamlLoader)
    return {} if parsed is None else parsed


config = OmegaConf.create(load_config("wmdp_ga_lat_config.yaml"))

## Get model
Make sure there is a `.env` file in the same directory as this notebook that sets `HUGGINGFACE_API_KEY` to a Hugging Face read token with access to Llama-2-7B-chat.

In [None]:
# Short aliases for Hugging Face model repo IDs.
# NOTE(review): `model_dict` is not referenced in the cells shown here and is not
# passed to `load_model`; which model gets loaded is presumably decided inside
# `wmdp_utils.load_model` — TODO confirm against that helper.
# NOTE(review): verify the two 34B repo IDs exist on the Hugging Face Hub before use.
model_dict = dict(
    LLAMA_7B="meta-llama/Llama-2-7b-chat-hf",
    LLAMA_34B="meta-llama/Llama-2-34b-chat-hf",
    ZEPHYR_7B="HuggingFaceH4/zephyr-7b-beta",
    ZEPHYR_34B="HuggingFaceH4/zephyr-34b-beta",
)

# The token read from `.env` above is the only argument given to the loader.
model, tokenizer = load_model(hf_access_token)

## Get data

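Download the WMDP corpora into `wmdp_data/` the first time this notebook is run on a machine. The bio forget corpus is not downloaded here; it must be placed at `wmdp_data/bio-forget-corpus.jsonl` separately.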

In [None]:
# Cache the cais/wmdp-corpora configs listed below as local JSONL files, downloading
# each one only when it is not already present. The existence check and the write
# must use the same directory (`wmdp_data/`); previously the check looked in
# `wmdp/data/`, so the cache never hit and everything was re-downloaded on every run.
WMDP_DATA_DIR = "wmdp_data"
os.makedirs(WMDP_DATA_DIR, exist_ok=True)

dataset_names = ["bio-retain-corpus", "cyber-forget-corpus", "cyber-retain-corpus"] # Everything on cais/wmdp-corpora
for name in dataset_names:
    local_path = f"{WMDP_DATA_DIR}/{name}.jsonl"
    if not os.path.exists(local_path):
        dataset = load_dataset('cais/wmdp-corpora', name)
        dataset["train"].to_json(local_path, batch_size=128)

# TODO: add loading code for the WMDP bio forget corpus. It is not downloaded above;
# `load_targeted_wmdp_data` reads it from wmdp_data/bio-forget-corpus.jsonl.

In [None]:
def load_targeted_wmdp_data(config: OmegaConf) -> tuple[Dataset, Dataset]:
    """
    Load the WMDP forget and retain corpora and pair them row-by-row.

    Each returned Hugging Face Dataset has two columns:
    - 'adv_tokens': text from the forget corpus.
    - 'def_tokens': text from the retain corpus.
    It is truncated to the length of the shorter of the two corpora.

    Supported settings are comma-separated strings; whitespace around names is ignored.
        config.forget_corpora: must name both 'bio-forget-corpus' and
            'cyber-forget-corpus'. Other names are ignored. The bio forget corpus is
            read from the local file wmdp_data/bio-forget-corpus.jsonl. The cyber
            forget corpus is loaded from the cais/wmdp-corpora hub dataset.
        config.retain_corpora: either 'bio-retain-corpus,cyber-retain-corpus' (WMDP
            retain corpora from cais/wmdp-corpora) or 'wikitext,wikitext'
            (WikiText-2 raw test split, paired with both bio and cyber).

    Args:
        config: OmegaConf object providing `forget_corpora` and `retain_corpora`.

    Returns:
        (bio_dataset, cyber_dataset).

    Raises:
        ValueError: if either setting is not one of the supported options. Both are
            checked before anything is downloaded or read.
    """
    retain_corpora = [name.strip() for name in config.retain_corpora.split(",")]
    forget_corpora = [name.strip() for name in config.forget_corpora.split(",")]

    # Validate up front: previously an unsupported setting surfaced later as a
    # NameError on an unassigned local.
    required_forget = ["bio-forget-corpus", "cyber-forget-corpus"]
    missing_forget = [name for name in required_forget if name not in forget_corpora]
    if missing_forget:
        raise ValueError(
            f"forget_corpora must include {required_forget}; missing {missing_forget} "
            f"(got {config.forget_corpora!r})."
        )
    if retain_corpora not in (["bio-retain-corpus", "cyber-retain-corpus"], ["wikitext", "wikitext"]):
        raise ValueError(
            "retain_corpora must be 'bio-retain-corpus,cyber-retain-corpus' or "
            f"'wikitext,wikitext' (got {config.retain_corpora!r})."
        )

    hf_location = "cais/wmdp-corpora"

    # Forget corpora -> 'adv_tokens'. Only the bio forget corpus is read from disk.
    bio_forget = load_dataset("json", data_files="wmdp_data/bio-forget-corpus.jsonl", split="train")
    bio_forget = bio_forget.rename_column("text", "adv_tokens")
    cyber_forget = load_dataset(hf_location, name="cyber-forget-corpus", split="train")
    cyber_forget = cyber_forget.rename_column("text", "adv_tokens")

    # Retain corpora -> 'def_tokens'.
    if retain_corpora[0] == "wikitext":
        wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        bio_retain = cyber_retain = wikitext.rename_column("text", "def_tokens")
    else:
        bio_retain = load_dataset(hf_location, name="bio-retain-corpus", split="train")
        bio_retain = bio_retain.rename_column("text", "def_tokens")
        cyber_retain = load_dataset(hf_location, name="cyber-retain-corpus", split="train")
        cyber_retain = cyber_retain.rename_column("text", "def_tokens")

    def pair_rows(forget: Dataset, retain: Dataset) -> Dataset:
        """Truncate both to the shorter length and attach retain text as a 'def_tokens' column."""
        n_rows = min(len(forget), len(retain))
        # A single column read + add_column replaces a per-row map that indexed
        # into `retain` once for every forget row.
        retain_texts = retain.select(range(n_rows))["def_tokens"]
        return forget.select(range(n_rows)).add_column("def_tokens", retain_texts)

    # The bio forget JSONL carries paper metadata that the collators do not use.
    bio_dataset = pair_rows(bio_forget, bio_retain).remove_columns(["title", "abstract", "doi"])
    cyber_dataset = pair_rows(cyber_forget, cyber_retain)

    return bio_dataset, cyber_dataset


def make_targeted_wmdp_dataloaders(
    config: OmegaConf,
    bio_dataset: Dataset,
    cyber_dataset: Dataset, 
    tokenizer: AutoTokenizer, 
) -> tuple[DataLoader, DataLoader]:
    """
    Wrap the paired bio and cyber WMDP datasets in training DataLoaders.

    Intended only for the (bio, cyber) datasets produced by `load_targeted_wmdp_data`.
    Each loader:
    - shuffles;
    - uses `config.lat_batch_size` and drops the final incomplete batch;
    - collates with its own `WMDPLATTargetedDataCollator` truncating at
      `config.data_truncate_length`.

    Args:
        config: OmegaConf object supplying `lat_batch_size` and `data_truncate_length`.
        bio_dataset: paired bio dataset.
        cyber_dataset: paired cyber dataset.
        tokenizer: tokenizer handed to the collator.

    Returns:
        (bio dataloader, cyber dataloader).
    """
    def build(dataset: Dataset) -> DataLoader:
        return DataLoader(
            dataset,
            shuffle=True,
            batch_size=config.lat_batch_size,
            collate_fn=WMDPLATTargetedDataCollator(tokenizer, truncate_length=config.data_truncate_length),
            drop_last=True,
        )

    return build(bio_dataset), build(cyber_dataset)


def make_untargeted_wmdp_dataloaders(
    config: OmegaConf,
    data: Union[list[str], list[list[str]]], 
    tokenizer: AutoTokenizer, 
    sft: Optional[bool] = False,
) -> Union[DataLoader, tuple[DataLoader, DataLoader]]:
    """
    Build DataLoaders that collate with `WMDPLATDataCollator`.

    Two modes:
      * sft truthy: `data` is a single corpus. Returns one DataLoader batched by
        `config.sft_batch_size`, used for the SFT portion of the R2D2 loss in LAT.
      * otherwise: `data` is the WMDP unlearn corpora as [bio, cyber]. Returns one
        DataLoader per corpus, batched by `config.lat_batch_size`. This mode is for
        runs where both attack and defence train on the WMDP unlearn corpora.

    By default the paper uses the targeted setup instead. There the defence trains
    toward WikiText and the attack toward the WMDP unlearn corpora, and the away
    losses are gradient ascent on those same datasets, swapped between attack and
    defence.

    Every loader:
    - shuffles and drops the final incomplete batch;
    - gets its own collator truncating at `config.data_truncate_length`.

    Args:
        config: OmegaConf object created from the YAML file.
        data: one corpus (sft mode) or the [bio, cyber] pair.
        tokenizer: tokenizer handed to the collator.
        sft: selects the single-loader SFT mode described above.

    Returns:
        A single DataLoader (sft mode) or a (bio, cyber) tuple of DataLoaders.
    """
    def build(dataset, batch_size):
        collator = WMDPLATDataCollator(tokenizer, truncate_length=config.data_truncate_length)
        return DataLoader(
            dataset,
            shuffle=True,
            batch_size=batch_size,
            collate_fn=collator,
            drop_last=True,
        )

    if sft:
        return build(data, config.sft_batch_size)

    return build(data[0], config.lat_batch_size), build(data[1], config.lat_batch_size)


def load_sft_dataset(config) -> list[str]:
    """
    Load the text corpus used for supervised fine-tuning. Works for wikitext and alpaca.

    Args:
        config: object whose `sft_corpora` field is either:
            - "wikitext": WikiText-2 raw, test split;
            - "alpaca": tatsu-lab/alpaca, train split.

    Returns:
        Every row's 'text' field converted to str, in dataset order.

    Raises:
        ValueError: if `config.sft_corpora` names an unsupported corpus. Previously
            this fell through to an UnboundLocalError on the dataset variable.
    """
    if config.sft_corpora == "wikitext":
        raw_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    elif config.sft_corpora == "alpaca":
        raw_data = load_dataset("tatsu-lab/alpaca", "default", split="train")
    else:
        raise ValueError(
            f"Unsupported sft_corpora {config.sft_corpora!r}; expected 'wikitext' or 'alpaca'."
        )
    return [str(row["text"]) for row in raw_data]


def _repeat_forever(loader):
    """
    Yield batches from `loader` endlessly, starting a new pass whenever it is exhausted.

    This replaces `itertools.cycle`, which stores every item from the first pass and
    then replays the stored copies. That kept a whole epoch of collated batches in
    memory and froze a shuffling DataLoader into its first-pass order.

    Here each pass re-iterates `loader`, so a shuffling loader draws a fresh order.
    If a pass yields no batches, the generator stops instead of looping forever,
    matching `itertools.cycle` on an empty iterable.
    """
    while True:
        produced_batch = False
        for batch in loader:
            produced_batch = True
            yield batch
        if not produced_batch:
            return


# The targeted loaders below need both the bio and cyber forget corpora.
assert config.forget_corpora == 'bio-forget-corpus,cyber-forget-corpus'
bio_dataset, cyber_dataset = load_targeted_wmdp_data(config)
bio_dataloader, cyber_dataloader = make_targeted_wmdp_dataloaders(config, bio_dataset, cyber_dataset, tokenizer)

# Endless SFT batch stream for the defence's SFT loss term, only when that term is active.
if config.def_config.def_loss_coefs.sft > 0:
    sft_dataset: list[str] = load_sft_dataset(config)
    sft_dataloader = _repeat_forever(
        make_untargeted_wmdp_dataloaders(config, sft_dataset, tokenizer, sft=True)
    )
else:
    assert config.def_config.def_loss_coefs["sft"] == 0, "def_loss_coefs.sft must be >= 0"
    sft_dataloader = None

## Set up trainer

In [None]:
# Collect the trainer's keyword arguments (same keys, values and order as passed
# directly before), then construct the trainer from them.
# NOTE(review): PGD layers, epsilon and the layer module come from `config.adv_optim`,
# while the adversary's learning rate, iteration count, loss coefficients and PCA
# options come from `config.pgd_config` — confirm this split matches the YAML.
lat_trainer_kwargs = dict(
    # Train on bio or cyber only, selected by `use_just_bio`.
    dataloader=bio_dataloader if config.use_just_bio else cyber_dataloader,
    pgd_layers=config.adv_optim.pgd_layers,
    model_layers=config.model_layers,
    epsilon=config.adv_optim.epsilon,
    init_callback=eval_and_log,
    post_adv_callback=eval_and_log,
    post_def_callback=eval_and_log,
    outer_learning_rate=config.def_config.outer_learning_rate,
    inner_learning_rate=config.pgd_config.inner_learning_rate,
    # NOTE(review): the config's `num_epochs` is what the trainer receives as num_steps.
    num_steps=config.num_epochs,
    pgd_iterations_per_step=config.pgd_config.pgd_iterations_per_step,
    model_iterations_per_step=config.def_config.model_iterations_per_step,
    model_layers_module=config.adv_optim.model_layers_module,
    only_train_lora=config.only_train_lora,
    sft_dataloader=sft_dataloader,
    adv_loss_coefs=config.pgd_config.adv_loss_coefs,
    def_loss_coefs=config.def_config.def_loss_coefs,
    max_batch_per_acc=config.minibatch_size,
    clip_grad=config.clip_grad,
    # NOTE(review): keyword spelled `reinitialize_dev_optim` but fed from
    # `reinitialize_def_optim`; presumably it must match the WMDPProjectedGradLAT
    # signature — verify before renaming either side.
    reinitialize_dev_optim=config.def_config.reinitialize_def_optim,
    time_limit=config.time_limit,
    device="cuda",
    pca_kwargs=config.pgd_config.pca_kwargs,
    config=config,  # Use for more flexible parameter passing
)
run_lat = WMDPProjectedGradLAT(model, **lat_trainer_kwargs)

Set the run name and train. By default the WandB run name encodes the hyperparameters that are varied most often; change which ones it includes as needed.

In [None]:
# Timestamp suffix keeps run names unique. `datetime` is imported as a module, so the
# class must be reached as `datetime.datetime`; the previous `datetime.now()` raised
# AttributeError.
run_timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# The name records the hyperparameters most often varied between runs. Layers and
# epsilon come from `config.adv_optim`, the same fields passed to WMDPProjectedGradLAT
# as `pgd_layers` / `epsilon`, so the name labels the values actually used in training.
run_name = (
    f"{config.model_name}"
    f"_layers-{config.adv_optim.pgd_layers}"
    f"_epsilon-{config.adv_optim.epsilon}"
    f"_defsft-{config.def_config.def_loss_coefs.sft:.2f}"
    f"_deftowards-{config.def_config.def_loss_coefs.toward:.2f}"
    f"_defaway-{config.def_config.def_loss_coefs.away:.2f}"
    f"_pgdtowards-{config.pgd_config.adv_loss_coefs.toward:.2f}"
    f"_defits-{config.def_config.model_iterations_per_step}"
    f"_pgdits-{config.pgd_config.pgd_iterations_per_step}"
    f"_{run_timestamp}"
)

run_lat.train("wmdp_lat", name=run_name)