In [None]:
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import pandas
from einops import repeat, rearrange
import time

import importlib
# reload the local `utils` module so edits to it take effect without restarting the kernel

import utils
importlib.reload(utils)


from utils import *

In [None]:
# Load the tokenizer and model.
# `device_map="auto"` lets accelerate place the weights. Calling `.to(device)` on a model
# dispatched this way is redundant at best, and accelerate refuses it when modules are
# offloaded or split across devices, so the model is left where it was placed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_name = "lmsys/vicuna-7b-v1.5"

# Left padding keeps each prompt's last token at the end of its row for batched generation.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# Pad with EOS (presumably the tokenizer defines no pad token of its own — confirm).
tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    use_cache=False,
    low_cpu_mem_usage=True,
).eval()

In [None]:
# Vicuna v1.5 conversation template, consumed by utils.prepare_input.
# Each role maps to a two-string list (presumably [prefix, suffix] — confirm in utils).
# For raw prompts without any template, set every entry to ["", ""].
chat_format = dict(
    user=["USER: ", ""],
    assistant=[" ASSISTANT:", ""],
    sep=["", ""],
)

In [None]:
# --- Sampling settings shared by every generation call -----------------------
temperature = 1.0
top_p = 0.95           # nucleus sampling: keep tokens up to cumulative probability top_p
top_k = None           # None: no top-k cut-off (assuming utils forwards it to sampling)

# --- Beam search over adversarial suffix tokens -------------------------------
k1 = 15                # beam width: candidates kept after every step
k2 = 30                # candidates generated from each beam entry per step
num_adv_tokens = 40    # adversarial tokens appended to each prompt
lookahead_length = 10  # target tokens included in the attack objective
batch_size = 16        # batch size for the perplexity scoring calls

# --- Data and evaluation -----------------------------------------------------------
data_file = "data/harmful_behaviors.csv"
prompt_tag = "goal"          # CSV column with the prompts
target_tag = "target"        # CSV column with the target responses
begin = 1                    # first (0-based) row to use
num_samples = 5              # number of consecutive rows to use
num_tokens_to_generate = 40  # max new tokens when sampling answers

# Load prompt/target pairs and keep rows [begin, begin + num_samples).
# The path comes from `data_file`, so the parameter cell is the single place to change it.
data = pandas.read_csv(data_file)
sample_rows = data.iloc[begin: begin + num_samples]
prompts = sample_rows[prompt_tag].tolist()
targets = sample_rows[target_tag].tolist()

In [None]:
# Baseline: sample the model's answers to the chat-formatted, unmodified prompts.
formated_prompts = prepare_input(chat_format, prompts)
org_answers = generate(
    model,
    tokenizer,
    formated_prompts,
    max_new_tokens=num_tokens_to_generate,
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
)

# Print at most the first 10 prompt/answer pairs.
n_preview = min(10, num_samples)
for prompt_text, answer_text in zip(formated_prompts[:n_preview], org_answers[:n_preview]):
    print(f"Prompt: {prompt_text}")
    print(f"Original answer: {answer_text}")
    print()

In [None]:

# Beam search over adversarial suffix tokens, one new token per step.
# At each step every candidate is scored by calc_perplexity on (candidate + target tokens),
# and choose_best_candidate keeps k1 of the k1 * k2 candidates per prompt.
target_tokens = get_target_tokens(tokenizer, targets, lookahead_length).to(device)
inputs = tokenizer(formated_prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)

# First token: seed the beam with k1 sampled continuations of every prompt.
inputs = generate_and_extend(model, tokenizer, inputs, max_new_tokens=1, do_sample=True,
                             temperature=temperature, top_p=top_p, top_k=top_k, k=k1)

# The appended target block is the same at every step: each prompt's target tokens
# repeated once per candidate (k1 * k2 candidates per prompt, prompt-major order).
# Build it and its all-ones mask once instead of inside the loop.
target_ids_per_candidate = repeat(target_tokens['input_ids'], 'b t -> (b r) t', r=k1 * k2)
target_mask_per_candidate = torch.ones_like(target_ids_per_candidate)

# Remaining num_adv_tokens - 1 tokens (the first one was added above).
for _ in range(1, num_adv_tokens):
    # Expand every beam entry into k2 candidates, each with one more sampled token.
    inputs = generate_and_extend(model, tokenizer, inputs, max_new_tokens=1, do_sample=True,
                                 temperature=temperature, top_p=top_p, top_k=top_k, k=k2)
    print(inputs['input_ids'].shape)

    # Append the target tokens so the objective scores the target given each candidate.
    total_tokens = {
        'input_ids': torch.cat([inputs['input_ids'], target_ids_per_candidate], dim=-1),
        'attention_mask': torch.cat([inputs['attention_mask'], target_mask_per_candidate], dim=-1),
    }
    perplexity = calc_perplexity(total_tokens, model, batch_size=batch_size)

    # Shrink back to k1 candidates per prompt.
    inputs = choose_best_candidate(inputs, perplexity, best_k=k1, out_of=k1 * k2, device=device)

# NOTE(review): this final ranking scores the candidates WITHOUT the target tokens,
# unlike the per-step objective above — confirm that is intended.
perplexity = calc_perplexity(inputs, model, batch_size=batch_size)
final_attack = choose_best_candidate(inputs, perplexity, best_k=1, out_of=k1, device=device)

In [None]:
# Decode the selected adversarial prompts and sample the model's answers to them.
final_attack_sentences = tokenizer.batch_decode(final_attack['input_ids'], skip_special_tokens=True)
attacked_answers = generate(
    model, tokenizer, final_attack_sentences,
    max_new_tokens=num_tokens_to_generate,
    do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k,
)

# Side-by-side comparison for at most the first 10 samples.
n_preview = min(10, num_samples)
preview_rows = zip(
    final_attack_sentences[:n_preview],
    org_answers[:n_preview],
    attacked_answers[:n_preview],
)
for attacked_prompt, original_answer, attacked_answer in preview_rows:
    print(f"attacked prompt: {attacked_prompt}")
    print(f"original answer: {original_answer}")
    print(f"attacked answer: {attacked_answer}")
    print()