# Latent Space Adversarial Attack

This notebook demonstrates conducting latent space adversarial attacks on LLMs. These particular demo attacks are created using projected gradient descent to make a model more jailbreakable.

## Imports

In [None]:
%load_ext autoreload
%autoreload 2
import os
import torch
import wandb
import sys

from dotenv import load_dotenv
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig

# Load variables from a local .env file (if present) so the Hugging Face
# access token can be read from the environment.
load_dotenv()
hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

# Work from the parent directory and put it on sys.path; the project-local
# imports below this cell (`latent_at`, `tasks`) presumably live there.
# The guard makes the cell safe to re-run: without it, every re-execution in
# the same kernel would climb one more directory level and break the imports.
if "_repo_root" not in globals():
    os.chdir("../")
    _repo_root = os.getcwd()
cwd = os.getcwd()
if cwd not in sys.path:
    sys.path.insert(0, cwd)

from latent_at import *
from tasks.harmbench.HarmBenchTask import HarmBenchTask
from tasks.harmbench.FastHarmBenchEvals import run_attack_evals

In [None]:
# Flip this flag to choose which Llama chat checkpoint gets loaded.
use_llama2 = True
model_name = (
    "meta-llama/Llama-2-7b-chat-hf"  # llama2-7b
    if use_llama2
    else "meta-llama/Meta-Llama-3-8B-Instruct"  # llama3-8b
)

## Model

In [None]:
model_dtype = torch.bfloat16
device = "cuda"
run_start_evals = False

# Resolve the model family *before* loading any weights, so an unsupported
# `model_name` fails immediately instead of after a multi-GB download/load.
if "Llama-2" in model_name:
    model_type = "llama2"
    tokenizer_name = model_name
elif "Llama-3" in model_name:
    model_type = "llama3"
    tokenizer_name = model_name
elif "zephyr" in model_name or "mistral" in model_name:
    model_type = "zephyr"
    # Zephyr/Mistral runs always use the zephyr-7b-beta tokenizer.
    tokenizer_name = "HuggingFaceH4/zephyr-7b-beta"
else:
    raise ValueError(f"Unsupported model type: {model_name}")

# Pass the same access token used for the model: the Llama repositories are
# gated, so the tokenizer download needs it as well when no cached login exists.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=hf_access_token)
if model_type == "zephyr":
    tokenizer.pad_token_id = tokenizer.unk_token_id
else:
    # The Llama branches reuse EOS as the padding token.
    tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=hf_access_token,
    torch_dtype=model_dtype,
).to(device)

## Data

In [None]:
# Keyword arguments common to both HarmBench task objects.
_shared_task_kwargs = dict(
    tokenizer=tokenizer,
    gen_batch_size=1,
    cls_batch_size=1,
    device=device,
    train_test_split=.8,
)

# AdvBench behaviors.
advbench_data = HarmBenchTask(data_name="advbench", **_shared_task_kwargs)

# HarmBench text behaviors, limited to the listed functional categories.
harmbench_data = HarmBenchTask(
    data_name="harmbench_text",
    func_categories=["standard", "contextual"],
    **_shared_task_kwargs,
)

# train_behaviors = advbench_data.train_behaviors + harmbench_data.train_behaviors  

# Default system prompt; the non-Llama branch below replaces it with "".
sys_prompt = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

# Choose how prompts/completions are formatted for each model family.
if model_type == "llama2":
    # Rely on the tokenizer's own chat template.
    use_tokenizer_template = True
    custom_prompt_template = None
    custom_completion_template = None
elif model_type == "llama3":
    use_tokenizer_template = False
    # Every Llama-3 header token must be closed with "|>"; the user header was
    # previously written as "<|end_header_id|" (missing ">"), which is not a
    # special token and would be tokenized as plain text.
    custom_prompt_template = "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    custom_completion_template = "{completion}"
else:
    sys_prompt = ""
    use_tokenizer_template = False
    custom_prompt_template = "<|user|>\n{prompt}</s> \n <|assistant|>\n"
    custom_completion_template = "{completion}"
 
# Tokenized chat dataset built from the train split of
# Baidicoot/augmented_advbench_v4 ("rejected" -> adv column,
# "chosen" -> def column), formatted with the templates chosen above.
_lat_dataset_kwargs = {
    "dataset": "Baidicoot/augmented_advbench_v4",
    "adv_column": "rejected",
    "def_column": "chosen",
    "split": "train",
    "use_tokenizer_template": use_tokenizer_template,
    "system_prompt": sys_prompt,
    "custom_prompt_template": custom_prompt_template,
    "custom_completion_template": custom_completion_template,
}
lat_dataset = process_generic_chat_dataset(tokenizer, **_lat_dataset_kwargs)

# Collator configured with the tokenizer's pad id and a 2048 truncation length.
_lat_collator = LatentAdversarialTrainingDataCollator(
    tokenizer.pad_token_id,
    truncate_length=2048,
)

lat_dataloader = DataLoader(
    lat_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    collate_fn=_lat_collator,
)

# Grab a single batch; it is reused by the generation and attack cells below.
dataloader = iter(lat_dataloader)
batch = next(dataloader)


## Clean Performance

In [None]:
# Generate from the first example of the batch only (`[[0]]` keeps the batch
# dimension). Use the notebook-wide `device` rather than a hard-coded "cuda"
# so the inputs always land on the same device as the model.
outputs = model.generate(
    batch["adv_tokens"][[0]].to(device),
    max_length=batch["adv_tokens"].shape[1] + 200,
)

# Newlines are stripped from both strings so the decoded prompt can be used
# as a character-length prefix of the decoded prompt+completion.
prompt = tokenizer.decode(batch["adv_tokens"][0]).replace('\n', '')
print("Prompt:\n" + prompt)
print()
prompt_response = tokenizer.decode(outputs[0]).replace('\n', '')
print("Completion:\n" + prompt_response[len(prompt):])

## Latent Space Attack

In [None]:
# Run PGD on the sampled batch. Returns the attack losses and the wrapper
# objects, which are toggled via `.enabled` in the next cell.
losses, wrappers = projected_gradient_descent(
    batch=batch,
    model=model,
    model_layers_module="model.layers",  # where the model layers are
    layer=["embedding", 8, 16, 24, 30],  # layers to attack
    epsilon=6.0,  # attack l2 constraint
    l2_regularization=0.0,  # coef for l2 penalty on the attack
    learning_rate=1e-3,  # attack step size
    pgd_iterations=32,  # number of PGD iters
    loss_coefs={"toward": 0.5, "away": 0.5},  # coefficients for the attack's toward and away losses
    log_loss=True,  # whether to use a log loss instead of a crossentropy one
    device=device,  # same device the model was moved to, not a hard-coded "cuda"
)

## Attacked Performance

In [None]:
for wrapper in wrappers:
    wrapper.enabled = True  # the wrappers should already be enabled, so this is redundant

# The whole batch is passed here (not just example 0) — presumably because the
# attack was fitted on the full batch; TODO confirm against
# projected_gradient_descent. Inputs go to the notebook-wide `device` rather
# than a hard-coded "cuda".
outputs = model.generate(
    batch["adv_tokens"].to(device),
    max_length=batch["adv_tokens"].shape[1] + 200,
)

# Only the first example is printed, mirroring the clean-performance cell.
prompt = tokenizer.decode(batch["adv_tokens"][0]).replace('\n', '')
print("Prompt:\n" + prompt)
print()
prompt_response = tokenizer.decode(outputs[0]).replace('\n', '')
print("Completion:\n" + prompt_response[len(prompt):])