In [None]:
%load_ext autoreload
%autoreload 2
import os
import torch
import wandb
import sys

from dotenv import load_dotenv
from peft import AutoPeftModelForCausalLM
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig

load_dotenv()
# None when HUGGINGFACE_API_KEY is not set in the environment / .env file.
hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

# The project packages (`latent_at`, `tasks`) are imported from the parent
# directory of this notebook. Resolve that directory only the FIRST time this
# cell runs: a bare `os.chdir("../")` is not idempotent, and re-running the
# cell would keep walking up the tree and break the imports below.
if "REPO_ROOT" not in globals():
    REPO_ROOT = os.path.abspath("../")
os.chdir(REPO_ROOT)
cwd = os.getcwd()
if cwd not in sys.path:
    sys.path.insert(0, cwd)

from latent_at import *
from tasks.harmbench.HarmBenchTask import HarmBenchTask


# Get model

In [None]:
dtype = torch.bfloat16
device = "cuda"
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
SEED = 42

# Seed torch's global RNG before anything random happens (e.g. LoRA adapter
# initialisation, DataLoader shuffling) so that training runs are reproducible.
torch.manual_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_access_token)
# Pad with the <unk> id (presumably because the tokenizer defines no pad token)
# and pad on the left.
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.padding_side = "left"

llama = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=hf_access_token, torch_dtype=dtype).to(device)
# Rank-64 LoRA adapters on all attention projections plus the MLP up/down
# projections (gate_proj is intentionally not in the list).
peft_config = LoraConfig(
    r=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "up_proj",
        "down_proj",
    ],
)
llama = get_peft_model(llama, peft_config)

# Get data

In [None]:
# AdvBench behaviors; presumably 80% of them go to the training split.
# Generation and classification both run one example at a time.
advbench_data = HarmBenchTask(
    data_name="advbench",
    train_test_split=.8,
    tokenizer=tokenizer,
    device=device,
    gen_batch_size=1,
    cls_batch_size=1,
)

# Safety-oriented system prompt shared by the LAT and SFT datasets below.
sys_prompt = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

# NOTE(review): not referenced by any visible cell; kept in case something else reads it.
augment = False

def make_lat_dataloader(batch_size=16, truncate_length=2048):
    """Build a shuffled DataLoader over the AdvBench training behaviors for LAT.

    The dataset is re-tokenized on every call from the notebook-level
    ``advbench_data``, ``tokenizer`` and ``sys_prompt``.

    Args:
        batch_size: Number of examples per batch.
        truncate_length: Forwarded to ``LatentAdversarialTrainingDataCollator``
            (defaults to 2048, the value previously hard-coded here).

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.
    """
    lat_dataset = tokenized_behavior_dataset(
        advbench_data.train_behaviors,
        tokenizer,
        good_refuse=True,   # presumably: the "good" completion is a refusal
        bad_refuse=False,   # presumably: the "bad" completion complies
        system_prompt=sys_prompt,
    )
    collator = LatentAdversarialTrainingDataCollator(
        tokenizer.pad_token_id,
        truncate_length=truncate_length,
    )
    return DataLoader(
        lat_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )

# Benign chat data for the defender's "sft" loss term. Only the response column
# is used (no adversarial column); the tokenizer's own chat template is applied
# together with the shared system prompt.
sft_dataset = process_generic_chat_dataset(
    tokenizer,
    dataset="VH1213141516/benign_data_v1",
    split="train",
    def_column="response",
    adv_column=None,
    system_prompt=sys_prompt,
    use_tokenizer_template=True,
)


def make_sft_dataloader(batch_size=16, truncate_length=2048):
    """Build a shuffled DataLoader over the notebook-level ``sft_dataset``.

    Args:
        batch_size: Number of examples per batch.
        truncate_length: Forwarded to ``LatentAdversarialTrainingDataCollator``
            (defaults to 2048, the value previously hard-coded here).

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.
    """
    collator = LatentAdversarialTrainingDataCollator(
        tokenizer.pad_token_id,
        truncate_length=truncate_length,
    )
    return DataLoader(
        sft_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )

def post_def_callback(losses, epoch):
    """Send the defender's loss dictionary to Weights & Biases.

    ``epoch`` is part of the callback signature but is not logged.
    """
    wandb.log(data=losses)

# Set up trainer

In [None]:
# Projected-gradient latent adversarial training on the LoRA-wrapped model.
# NOTE(review): the meaning of individual hyperparameters is inferred from their
# names and not verified against ProjectedGradLAT's implementation.
pgd_trainer = ProjectedGradLAT(
    model=llama,
    dataloader=make_lat_dataloader(batch_size=2),
    # presumably: adversary index -> transformer layer it perturbs
    pgd_layers={
        0: 8,
        1: 16,
        2: 24,
        3: 30,
    },
    # All transformer layers. Read from the config instead of hard-coding 32
    # (Llama-2-7b's layer count) so swapping the base model needs no edit here.
    model_layers=range(llama.config.num_hidden_layers),
    epsilon=6.0,
    outer_learning_rate=2e-5,
    inner_learning_rate=5e-2,
    init_callback=None,
    post_def_callback=post_def_callback,
    model_iterations_per_step=4,
    only_train_lora=True,
    # Path to the decoder layers through the PEFT wrapper.
    model_layers_module="base_model.model.model.layers",
    sft_dataloader=make_sft_dataloader(batch_size=2),
    adv_loss_coefs={
        "away": 0.5,
        "toward": 0.5,
    },
    # Defender loss weights in the ratio away:sft:toward = 2:4:1, normalised to
    # sum to 1 (same float values as the previously hand-written decimals).
    def_loss_coefs={
        "away": 2 / 7,
        "sft": 4 / 7,
        "toward": 1 / 7,
    },
    pgd_iterations_per_step=16,
    num_steps=150,
    time_limit=None,
    max_batch_per_acc=4,
    N_checkpoints=None,
    checkpoint_dir="data/harmbench_checkpoints/TEST",
    # Reuse the device chosen in the model cell rather than repeating "cuda".
    device=device,
    reinitialize_dev_optim=True,
)

# Run!

In [None]:
# Run LAT under the given wandb project, then write the trained model to disk.
# If `pgd_trainer.model` is still the PEFT-wrapped model, save_pretrained stores
# the adapter weights rather than the full base model.
pgd_trainer.train(project_name="harmbench_TEST")
trained_model = pgd_trainer.model
trained_model.save_pretrained("save_here_DEBUG")