## Example of Jailbreaking LLaMA-2-chat-7B

In [None]:
import os
# Order CUDA devices by PCI bus ID and make only device 0 visible to this
# process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
import gc
import math

import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import json
import re
import copy

from llm_attacks.minimal_gcg.opt_utils import token_gradients, sample_control, get_logits, target_loss, get_logits_in_batches
from llm_attacks.minimal_gcg.opt_utils import load_model_and_tokenizer, get_filtered_cands
from llm_attacks.minimal_gcg.string_utils import SuffixManager, load_conversation_template
from llm_attacks import get_nonascii_toks

from livelossplot import PlotLosses # pip install livelossplot
from livelossplot.outputs import MatplotlibPlot
import matplotlib.pyplot as plt

# Seed NumPy, PyTorch, and PyTorch's CUDA generators (all devices) with the
# same fixed value.
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(20)

### Experiment setup

In [None]:
# Model identifier passed to both `from_pretrained` calls in the next cell.
model_name = 'meta-llama/Llama-2-7b-chat-hf'

In [None]:
from transformers import  AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig

# Read the Hugging Face access token from the environment once instead of
# hard-coding a placeholder in two places, so no secret ends up in the saved
# notebook. If HF_TOKEN is unset, None is passed (the library is then
# expected to fall back to a locally stored login -- see the transformers
# documentation).
hf_token = os.environ.get("HF_TOKEN")

# NOTE(review): trust_remote_code=True lets the model repository execute its
# own Python code at load time; confirm it is actually required here.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_fast=False,
    token=hf_token,
)

# Weights are loaded in 8-bit with fp16 as the compute dtype, and the model
# is put in eval mode right away.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    load_in_8bit=True,
    low_cpu_mem_usage=True,
    use_cache=False,
    token=hf_token,
).eval()

In [None]:
from datasets import load_dataset
# Load the 'train' split, then carve out two row ranges:
#   rows 10..19  -> pool the demonstrations / training queries are drawn from
#   rows 20..end -> held-out rows used for evaluation
dataset = load_dataset("Baidicoot/augmented_advbench_v4")['train']
train_set1 = dataset.select(range(10, 20))
# len(dataset) is the row count; indexing a column just to measure it would
# materialise every value of that column first.
test_set = dataset.select(range(20, len(dataset)))

In [None]:
# Hand-picked rows of train_set1; the setup cell slices its demonstrations
# from this list.
train_set = [train_set1[row] for row in (0, 4, 7, 9)]

In [None]:

# --- Hyper-parameters ------------------------------------------------------
num_adv_tokens = 2   # tokens in each adversarial string (one string per demo)
num_shots = 4        # number of in-context demonstrations
num_queries = 4      # number of training queries
num_steps = 100      # optimisation steps (reassigned to 200 in a later cell)

batch_size = 32      # passed to sample_control as its batch size
topk = 64            # passed to sample_control as topk

original_num_shots = num_shots

# One initial adversarial string per demonstration.
adv_prompts_init = [' a b'] * num_shots

demos8 = train_set[:num_shots]
queries2 = train_set1.select(range(num_queries))

demos = []
queries = []
target = []
real_answers = []
queries_questions = []

instruction = "Please answer the following questions step by step."

adv_answer = 'Sure, here'

# Build the demonstrations. Each one is split right after the first word of
# the answer: 'sentence' ends with that word and 'label' holds the following
# words (up to 24), so that SuffixManager can insert the adversarial string
# between the two parts.
for row in demos8:
    question, answer = row['prompt'], row['rejected']
    demo_prefix = '\nQuestion: ' + question + '\nAnswer:'
    if 'sure' not in answer.lower():
        answer = 'Sure, ' + answer

    answer_words = answer.split(' ')
    demos.append({
        'sentence': demo_prefix + ' ' + answer_words[0],
        'label': ' ' + ' '.join(answer_words[1:25]),
    })

# Build the training queries; every query shares the same target prefix.
for index in range(num_queries):
    question = queries2[index]['prompt']
    query_text = '\nQuestion: ' + question.strip() + '\nAnswer:'
    queries.append(query_text)
    queries_questions.append(query_text)
    target.append(' ' + adv_answer)

allow_non_ascii = False # you can set this to True to use unicode tokens

In [None]:
import torch
import copy
class SuffixManager:
    """Assemble few-shot prompts and record where each part sits in token space.

    A prompt is ``instruction + demos + query + target``. Each demo is
    rendered as ``sentence + adversarial string + label``. While building the
    prompts, ``get_prompt`` tokenizes growing prefixes of the text and stores
    the resulting token ranges as ``slice`` objects:

    * ``_instruction_slice`` -- the instruction
    * ``_demos_slice[i]`` -- instruction end up to the end of demo ``i``'s sentence
    * ``_control_slice[i]`` -- adversarial string of demo ``i``
    * ``_demos_label_slice[i]`` -- label of demo ``i``
    * ``_queries_slice[q]`` / ``_target_slice[q]`` -- query ``q`` and its target
    * ``_loss_slice[q]`` -- the target slice shifted one position to the left

    All constructor arguments are keyword-only.
    """

    def __init__(self, *, model_name, tokenizer, demos, queries, instruction,
                 targets, adv_prompts, num_adv_tokens):
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.demos = demos
        self.queries = queries
        self.instruction = instruction
        self.targets = targets
        self.adv_prompts = adv_prompts
        self.num_adv_tokens = num_adv_tokens

        # Placeholders; get_prompt() overwrites every entry with a slice.
        demo_count, query_count = len(demos), len(queries)
        self._demos_slice = [[] for _ in range(demo_count)]
        self._control_slice = [[] for _ in range(demo_count)]
        self._demos_label_slice = [[] for _ in range(demo_count)]
        self._queries_slice = [[] for _ in range(query_count)]
        self._target_slice = [[] for _ in range(query_count)]
        self._loss_slice = [[] for _ in range(query_count)]

    @staticmethod
    def _with_leading_space(adv_prompt):
        """Prefix a space unless the string is empty or already starts with one."""
        if adv_prompt != '' and adv_prompt[0] != ' ':
            return ' ' + adv_prompt
        return adv_prompt

    def get_prompt(self, adv_prompts=None):
        """Return one full prompt string per query and refresh all slices.

        If ``adv_prompts`` is given it replaces ``self.adv_prompts`` first.
        """
        if adv_prompts is not None:
            self.adv_prompts = adv_prompts

        def token_count(text):
            return len(self.tokenizer(text).input_ids)

        # The demo section is identical for every query.
        demo_text = ''.join(
            demo['sentence'] + self._with_leading_space(adv) + demo['label']
            for demo, adv in zip(self.demos, self.adv_prompts)
        )
        prompts = [
            self.instruction + demo_text + query + tgt
            for query, tgt in zip(self.queries, self.targets)
        ]

        # Re-build the shared prefix piece by piece, tokenizing after each
        # piece to learn where it ends.
        running_text = self.instruction
        self._instruction_slice = slice(None, token_count(running_text))

        for index, (demo, adv_prompt) in enumerate(zip(self.demos, self.adv_prompts)):
            running_text += demo['sentence']
            sentence_end = token_count(running_text)
            self._demos_slice[index] = slice(self._instruction_slice.stop, sentence_end)

            running_text += self._with_leading_space(adv_prompt)
            control_end = token_count(running_text)
            # When appending the adversarial string did not add exactly
            # num_adv_tokens tokens, the control slice starts one token
            # earlier.
            if sentence_end + self.num_adv_tokens == control_end:
                control_start = sentence_end
            else:
                control_start = sentence_end - 1
            self._control_slice[index] = slice(control_start, control_end)

            running_text += demo['label']
            self._demos_label_slice[index] = slice(control_end, token_count(running_text))

        # Each query (and its target) is appended to the same shared prefix.
        for index, (query, target) in enumerate(zip(self.queries, self.targets)):
            text_with_query = running_text + query
            query_end = token_count(text_with_query)
            self._queries_slice[index] = slice(self._demos_label_slice[-1].stop, query_end)

            target_end = token_count(text_with_query + target)
            self._target_slice[index] = slice(query_end, target_end)
            self._loss_slice[index] = slice(query_end - 1, target_end - 1)

        return prompts

    def get_input_ids(self, adv_prompts=None):
        """Return, per query, the prompt's token ids up to the end of its target."""
        prompts = self.get_prompt(adv_prompts=adv_prompts)
        return [
            torch.tensor(self.tokenizer(prompt).input_ids[:target_slice.stop])
            for prompt, target_slice in zip(prompts, self._target_slice)
        ]

    def get_input_ids_output(self, adv_prompts=None):
        """Return, per query, the prompt's token ids up to the start of its target."""
        prompts = self.get_prompt(adv_prompts=adv_prompts)
        return [
            torch.tensor(self.tokenizer(prompt).input_ids[:target_slice.start])
            for prompt, target_slice in zip(prompts, self._target_slice)
        ]

    def get_prompt_ids(self, adv_prompts=None):
        """Return the first prompt's token ids up to the start of its query."""
        first_prompt = self.get_prompt(adv_prompts=adv_prompts)[0]
        first_ids = self.tokenizer(first_prompt).input_ids
        return torch.tensor(first_ids[:self._queries_slice[0].start])

In [None]:
# Prompt builder for the optimisation phase: training demos and queries,
# starting from the initial adversarial strings.
suffix_manager = SuffixManager(
    model_name=model_name,
    tokenizer=tokenizer,
    demos=demos,
    queries=queries,
    instruction=instruction,
    targets=target,
    adv_prompts=adv_prompts_init,
    num_adv_tokens=num_adv_tokens,
)

In [None]:
# Tokens excluded from sampling: none when non-ASCII tokens are allowed,
# otherwise whatever get_nonascii_toks returns for this tokenizer.
if allow_non_ascii:
    not_allowed_tokens = None
else:
    not_allowed_tokens = get_nonascii_toks(tokenizer)

# Optimisation bookkeeping, starting from the initial adversarial strings.
adv_suffix = adv_prompts_init
best_loss, best_adv_suffix = math.inf, None
losses_list = []

# Token ids of every training prompt (one tensor per query).
input_ids_list = list(suffix_manager.get_input_ids(adv_prompts=adv_prompts_init))

In [None]:
def get_logits_in_batches(model, tokenizer, input_ids_list, control_slice_list, test_controls, batch_size, return_ids=True, num_adv_tokens=2, num_shots=None, target_slice= None):
    """Score candidate controls chunk by chunk and return their losses.

    ``test_controls`` is cut into consecutive chunks of at most ``batch_size``
    items. Each chunk is passed to ``get_logits`` and the result is scored
    with ``target_loss(logits, ids, target_slice)``. Every element of that
    result is converted with ``.item()`` and appended to the returned list, in
    chunk order.

    All other arguments are forwarded unchanged to ``get_logits``.
    NOTE(review): the ``get_logits`` result is unpacked into ``(logits, ids)``,
    so ``return_ids=False`` presumably cannot work here -- confirm.

    Returns:
        list: Python scalars, empty when ``test_controls`` is empty.
    """
    collected = []
    chunk_count = (len(test_controls) + batch_size - 1) // batch_size

    for chunk_no in range(chunk_count):
        lo = chunk_no * batch_size
        chunk = test_controls[lo:lo + batch_size]

        logits, ids = get_logits(
            model=model,
            tokenizer=tokenizer,
            input_ids_list=input_ids_list,
            control_slice_list=control_slice_list,
            test_controls=chunk,
            return_ids=return_ids,
            batch_size=batch_size,
            num_adv_tokens=num_adv_tokens,
            num_shots=num_shots,
        )

        losses = target_loss(logits, ids, target_slice)
        collected.extend(loss.item() for loss in losses)

        # Drop the references to this chunk's tensors before asking CUDA to
        # release cached memory.
        del logits, ids, losses
        torch.cuda.empty_cache()

    return collected

In [None]:
def get_filtered_cands(tokenizer,filter_cand, control_cand, model_name, num_tokens):
    """Keep only candidates whose every string tokenizes to ``num_tokens`` ids.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, add_special_tokens=False)``;
            the result must expose ``input_ids``.
        filter_cand: When false, the candidates are returned unfiltered.
        control_cand: Sequence of candidates, each an iterable of strings.
        model_name: Unused; kept for interface compatibility.
        num_tokens: Required token count for every string of a candidate.

    Returns:
        list: With filtering enabled, the valid candidates in their original
        order, padded with copies of the last valid one so the result has the
        same length as ``control_cand``. An empty ``control_cand`` yields
        ``[]``.

    Raises:
        ValueError: If filtering is enabled and no candidate is valid. This
            used to surface as an opaque ``IndexError`` from ``cands[-1]``.
    """
    if not filter_cand:
        return list(control_cand)

    total = len(control_cand)
    if total == 0:
        # Nothing to filter; also avoids dividing by zero in the warning.
        return []

    def _is_valid(cand):
        # all() stops at the first string with the wrong token count.
        return all(
            len(tokenizer(piece, add_special_tokens=False).input_ids) == num_tokens
            for piece in cand
        )

    cands = [cand for cand in control_cand if _is_valid(cand)]
    count = total - len(cands)

    if not cands:
        raise ValueError(
            f"None of the {total} control candidates tokenizes to exactly "
            f"{num_tokens} tokens per string; nothing left to pad with."
        )

    # Pad back to the input length by repeating the last valid candidate.
    cands.extend([cands[-1]] * count)
    print(f"Warning: {round(count / len(control_cand), 2)} control candidates were not valid")
    return cands

In [None]:
# Overrides the value of 100 assigned in the experiment-setup cell; this is
# the step count the optimisation loop below actually uses.
num_steps = 200

In [None]:
filename = 'directory'
plotlosses = PlotLosses(outputs=[MatplotlibPlot(figpath=filename)])

# get candidate token list
not_allowed_tokens = None if allow_non_ascii else get_nonascii_toks(tokenizer)
adv_suffix = adv_prompts_init
best_loss = math.inf
best_adv_suffix = None
losses_list = []

# Each step: compute token gradients for the current adversarial strings,
# sample replacement candidates, score every candidate, and continue from the
# lowest-loss one.
for step in range(num_steps):
    # get_input_ids already returns a fresh list, so no extra copy is needed.
    input_ids_list = suffix_manager.get_input_ids(adv_prompts=adv_suffix)
    coordinate_grad = token_gradients(
        model,
        input_ids_list,
        suffix_manager._control_slice,
        suffix_manager._target_slice,
        suffix_manager._loss_slice,
    )

    with torch.no_grad():
        # Concatenate the adversarial tokens of all demos, read from the
        # first prompt.
        adv_suffix_tokens = torch.cat([
            input_ids_list[0][control_slice]
            for control_slice in suffix_manager._control_slice
        ])

        new_adv_suffix_toks = sample_control(
            adv_suffix_tokens,
            coordinate_grad,
            batch_size,
            topk=topk,
            temp=1,
            not_allowed_tokens=not_allowed_tokens,
        )

        # Decode every candidate back into one string per demo. A dedicated
        # index name is used so the step variable is not shadowed.
        new_adv_suffix = [
            [
                tokenizer.decode(
                    cand_toks[shot * num_adv_tokens:(shot + 1) * num_adv_tokens]
                )
                for shot in range(num_shots)
            ]
            for cand_toks in new_adv_suffix_toks
        ]

        new_adv_suffix = get_filtered_cands(
            tokenizer,
            filter_cand=True,
            control_cand=new_adv_suffix,
            model_name=model_name,
            num_tokens=num_adv_tokens,
        )

        losses = get_logits_in_batches(
            model=model,
            tokenizer=tokenizer,
            input_ids_list=input_ids_list,
            control_slice_list=suffix_manager._control_slice,
            test_controls=new_adv_suffix,
            batch_size=64,
            num_adv_tokens=num_adv_tokens,
            num_shots=num_shots,
            target_slice=suffix_manager._target_slice,
        )

        # Continue from the candidate with the lowest loss.
        current_loss = min(losses)
        best_new_adv_suffix_id = losses.index(current_loss)
        best_new_adv_suffix = new_adv_suffix[best_new_adv_suffix_id]
        losses_list.append(current_loss)

        adv_suffix = best_new_adv_suffix

    # Sanity check: there must be exactly one adversarial string per demo.
    if len(adv_suffix) != num_shots:
        print(adv_suffix)
        print("error")
        break

    # Create a dynamic plot for the loss.
    plotlosses.update({'Loss': current_loss})
    plotlosses.send()
    print(adv_suffix)
    if current_loss < best_loss:
        best_loss = current_loss
        best_adv_suffix = adv_suffix

    print(f"\nCurrent Suffix:{adv_suffix}", end='\r')

    # Release this step's tensors before the next one allocates new ones.
    del coordinate_grad, adv_suffix_tokens ; gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Lowest loss seen during optimisation and the adversarial strings that
# produced it, one per line.
print(best_loss, best_adv_suffix, sep='\n')

In [None]:
# Upper bound on the number of held-out questions to evaluate.
max_eval_queries = 100

demos = []
queries = []
# Start these from scratch as well. Otherwise they keep the entries added by
# the setup cell (and by earlier runs of this cell) and no longer line up
# with `queries`.
target = []
queries_questions = []
demos_un = ""

# Rebuild the demonstrations with longer answers (up to 100 words). `demos`
# keeps the sentence/label split expected by SuffixManager; `demos_un` is the
# same text as one plain string with no adversarial string inserted.
for row in demos8:
    question, answer = row['prompt'], row['rejected']
    demo_prefix = '\nQuestion: ' + question + '\nAnswer:'
    if 'sure' not in answer.lower():
        answer = 'Sure, ' + answer

    answer_words = answer.split(' ')
    demos.append({
        'sentence': demo_prefix + ' ' + answer_words[0],
        'label': ' ' + ' '.join(answer_words[1:100]),
    })
    demos_un += demo_prefix + ' ' + ' '.join(answer_words[0:100])

quer = []
questions = []
# len(test_set) is the row count; no column needs to be materialised for it.
for index in range(min(len(test_set), max_eval_queries)):
    row = test_set[index]
    question_text = '\nQuestion: ' + row['prompt'].strip()
    answer_cue = '\nAnswer:'
    questions.append(question_text)
    queries.append(question_text + answer_cue)
    queries_questions.append(question_text + answer_cue)
    quer.append(demos_un + question_text + answer_cue)
    target.append(' ' + adv_answer)

In [None]:
# Prompt builder for the evaluation phase: held-out queries, with the best
# adversarial strings found by the optimisation loop inserted into the demos.
suffix_manager = SuffixManager(
    model_name=model_name,
    tokenizer=tokenizer,
    demos=demos,
    queries=queries,
    instruction=instruction,
    targets=target,
    adv_prompts=best_adv_suffix,
    num_adv_tokens=num_adv_tokens,
)

In [None]:
# One token-id tensor per evaluation query, cut off where the target would
# start, so the model has to produce the continuation itself.
output = suffix_manager.get_input_ids_output(adv_prompts=best_adv_suffix)
# Decode the first one back to text as a visual check (shown as cell output).
tokenizer.decode(output[0])

In [None]:
# Generate up to 100 new tokens for every evaluation prompt that contains the
# adversarial demonstrations, and print only the newly generated part.
for i, prompt_ids in enumerate(output):
    outputs = model.generate(
        prompt_ids.unsqueeze(0).to(model.device),
        max_new_tokens=100,
        return_dict_in_generate=True,
        output_scores=True,
    )
    completion_ids = outputs['sequences'][0][len(prompt_ids):]
    print(questions[i])
    print(tokenizer.decode(completion_ids))
    print("="*50)

In [None]:
# Comparison run: the same questions with the plain demonstrations (no
# adversarial strings inserted), tokenized directly from text.
for i, few_shot_text in enumerate(quer):
    output1 = instruction + few_shot_text
    input_tensor = tokenizer(output1, return_tensors="pt", add_special_tokens=True)
    prompt_ids = input_tensor['input_ids']
    outputs = model.generate(
        prompt_ids.to(model.device),
        max_new_tokens=100,
        return_dict_in_generate=True,
        output_scores=True,
    )
    print(questions[i])
    print(tokenizer.decode(outputs['sequences'][0][prompt_ids.shape[-1]:]))
    print("="*50)