In [8]:
from copy import deepcopy
import colorama
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed, AutoModel
from transformers.cache_utils import DynamicCache
from datasets import load_dataset
import huggingface_hub

# Reproducibility: seed NumPy, PyTorch (plus CUDA when a GPU is present) and transformers.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
set_seed(seed)


# Terminal color helpers: wrap `x` in a colorama foreground color, then reset the color.
def GREEN(x):
    return colorama.Fore.GREEN + x + colorama.Fore.RESET


def YELLOW(x):
    return colorama.Fore.YELLOW + x + colorama.Fore.RESET


def RED(x):
    return colorama.Fore.RED + x + colorama.Fore.RESET

In [None]:
# Model / quantization / attack configuration.
import os

# Alternatives that were tried:
# model_name = "meta-llama/Llama-3.2-1B-Instruct" # Tough cookie! Also, requires permissions through HF authentication
# model_name = "Qwen/Qwen3-0.6B"
# Set the GCG_MODEL_NAME environment variable (local path or Hub id) to run on another
# machine; the default is the path this notebook was developed with.
model_name = os.environ.get(
    "GCG_MODEL_NAME", "/u/anp407/Workspace/Huggingface/Qwen/Qwen3-Embedding-0.6B"
)
# 4-bit NF4 weights with bfloat16 compute.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Attack parameters
batch_size = 512 # Number of samples to optimize over (512 in GCG paper)
search_batch_size = 256 # Number of samples that actually run forward together
top_k = 256 # Number of top tokens to sample from (256 in GCG paper)
# NOTE(review): the saved progress bars in this notebook show 100 steps per run, so the
# outputs were produced with a different value than the one below.
steps = 1000 # Total number of optimization steps (100 in GCG paper)
suffix_length = 20 # Length of the suffix to be optimized (20 in GCG paper)
suffix_initial_token = " !" # Initial token repeated for the length of the suffix
system_prompt = "" # System prompt to be prepended to the input
dataset_index = 0 # Index of the sample in the dataset to attack
device_name = "cuda:0"

# Initial suffix
initial_suffix = suffix_initial_token * suffix_length

# Assertions
# Checked first so a zero chunk size fails clearly instead of with ZeroDivisionError below.
assert search_batch_size > 0, "search_batch_size must be positive"
assert batch_size % search_batch_size == 0, "Batch size must be divisible by search batch size (convenience)"

In [10]:
# Loading model and tokenizer
model = AutoModel.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map=device_name,
    trust_remote_code=True,
).eval()
# Embeddings in this notebook are read from the final sequence position
# (`last_hidden_state[:, -1]`). Left padding keeps that position a real token for every row
# of a padded document batch; the Qwen3-Embedding model card loads its tokenizer this way.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

# Getting suffix ids; sanity-check that the initial suffix is exactly `suffix_length` tokens.
ids_suffix = tokenizer.encode(initial_suffix, return_tensors="pt", add_special_tokens=False).to(model.device)
assert ids_suffix.shape[1] == suffix_length, f"Initial suffix length {ids_suffix.shape[1]} does not match expected length {suffix_length}. {suffix_initial_token} is not a valid initial token."

In [11]:
# Loading dataset
# (An earlier version used load_dataset("walledai/AdvBench", split='train') with its
# 'prompt' / 'target' fields.)
import json

# The file is a mapping from sample key to record; records are used in key order.
with open("nq_gemini_enriched.json") as f:
    dataset = json.load(f)
records = list(dataset.values())

queries = [record['question'] for record in records]
# NOTE: fields are deliberately cross-assigned. The optimization minimizes
# -cos(q, adv_docs) + cos(q, benign_docs), so here it pulls each query TOWARD the file's
# 'benign_texts' and AWAY from its 'adv_texts' (find a query that gives the correct answer
# when combined with the benign doc). The answer lists are swapped to match.
# TODO(review): confirm this role swap is intended.
incorrect_answers = [record['correct answer'] for record in records]
correct_answers = [record['incorrect answer'] for record in records]
adv_docs = [record['benign_texts'] for record in records]
benign_docs = [record['adv_texts'] for record in records]

In [12]:
# Where the optimized tokens are inserted into the query, and how many of them to try.
# NOTE(review): "benign" is handled as a "begin" position (tokens prepended to the query) and
# results are saved to gcg_all_result_begin.json, so it may be a typo for "begin". Confirm
# before renaming: the string is part of every result key.
positions = ["benign"]
lengths = list(range(1, 11))

In [13]:
# Running optimization with GCG
# For every (insertion position, number of optimized tokens, selected query) triple, run
# `steps` rounds of Greedy Coordinate Gradient search and record the lowest-loss tokens.
#
# Minimized objective:  loss = -min_j cos(q, adv_j) + max_k cos(q, benign_k)
# where q is the query embedding and adv/benign are embeddings of `adv_docs[i]` /
# `benign_docs[i]`. Documents, the gradient forward pass and candidate scoring all use
# the same last-token pooling, so candidates are ranked by the loss the gradient is taken of.


def _last_token_pool(hidden_states, attention_mask=None):
    """Pool (batch, seq, hidden) hidden states to (batch, hidden) at each row's last real token.

    Without a mask, or when every row ends in a real token (left padding / no padding),
    the final position is used; otherwise (right padding) the last real token's index
    is derived from the mask.
    """
    if attention_mask is None or bool(attention_mask[:, -1].bool().all()):
        return hidden_states[:, -1, :]
    last_positions = (attention_mask.sum(dim=1) - 1).to(hidden_states.device)
    rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
    return hidden_states[rows, last_positions]


def _tokenize(text):
    """Token ids of `text` (no special tokens) as a (1, n) tensor on the model device."""
    return tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)["input_ids"]


def _embed_docs(docs):
    """Embed a list of document strings with last-token pooling, without tracking gradients."""
    encoded = tokenizer(docs, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        # The attention mask keeps padding from being attended to / pooled.
        hidden = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"]).last_hidden_state
    return _last_token_pool(hidden, encoded["attention_mask"])


def _as_doc_list(docs, first_only=False):
    """Normalize a dataset document field (a string or a list of strings) to a non-empty list.

    With first_only=True only the first document is kept.
    Raises ValueError when there are no documents (an empty batch cannot be embedded).
    """
    doc_list = [docs] if isinstance(docs, str) else list(docs)
    if not doc_list:
        raise ValueError("Expected at least one document, got none")
    return doc_list[:1] if first_only else doc_list


def _split_query(query, position, length):
    """Return (text_before, text_mid, text_after); text_mid is `length` copies of " !".

    "benign" (the value used in `positions`) and "begin" both prepend the tokens to the query.
    Raises ValueError for an unknown position instead of silently reusing stale text.
    """
    text_mid = " !" * length
    if position in ("benign", "begin"):
        return "", text_mid, query
    if position == "mid":
        half = len(query) // 2
        return query[:half], text_mid, query[half:]
    if position == "end":
        return query, text_mid, ""
    raise ValueError(f"Unknown position {position!r}")


def _retrieval_loss(q_embeds, adv_norm, benign_norm):
    """Per-row loss -min cos(q, adv) + max cos(q, benign); doc embeddings must be L2-normalized."""
    q_norm = torch.nn.functional.normalize(q_embeds, p=2, dim=1)
    sim_adv = q_norm @ adv_norm.T
    sim_benign = q_norm @ benign_norm.T
    return -sim_adv.min(dim=1).values + sim_benign.max(dim=1).values


def _gcg_optimize(ids_before, ids_suffix, ids_after, adv_doc_embeds, benign_doc_embeds):
    """Run `steps` GCG iterations on the suffix tokens placed between the fixed ids.

    Returns (best_loss, ids_suffix_best): the lowest loss observed and the (1, n) suffix
    token ids that produced it.
    """
    embed_layer = model.get_input_embeddings()
    # Detached: only the gradient w.r.t. the one-hot suffix is needed, not the embedding table.
    embed_weights = embed_layer.weight.detach()
    num_suffix_tokens = ids_suffix.shape[1]  # may differ from `length` if " !" merges

    # Loop invariants: embeddings of the fixed text and the normalized document embeddings.
    with torch.no_grad():
        embeds_before = embed_layer(ids_before)
        embeds_after = embed_layer(ids_after)
        adv_norm = torch.nn.functional.normalize(adv_doc_embeds, p=2, dim=1)
        benign_norm = torch.nn.functional.normalize(benign_doc_embeds, p=2, dim=1)

    best_loss = float("inf")
    ids_suffix_best = ids_suffix.clone()

    for _ in tqdm(range(steps), desc="Optimization steps", unit="step"):
        # 1. Gradient of the loss w.r.t. a one-hot encoding of the current suffix.
        one_hot = torch.nn.functional.one_hot(ids_suffix, num_classes=embed_weights.shape[0])
        one_hot = one_hot.to(model.device, embed_weights.dtype).requires_grad_(True)
        full_embeds = torch.cat([embeds_before, one_hot @ embed_weights, embeds_after], dim=1)
        q_embeds = _last_token_pool(model(inputs_embeds=full_embeds).last_hidden_state)
        loss = _retrieval_loss(q_embeds, adv_norm, benign_norm).sum()
        loss.backward()
        # Tokens with the most negative gradient are the most promising substitutions.
        gradients = -one_hot.grad
        model.zero_grad(set_to_none=True)  # discard any parameter grads left by backward()

        current_loss = loss.item()
        if current_loss < best_loss:
            best_loss, ids_suffix_best = current_loss, ids_suffix.clone()

        # 2. Sample `batch_size` single-token substitutions from each position's top-k tokens.
        top_k_tokens = torch.topk(gradients, top_k, dim=-1).indices  # (1, num_suffix_tokens, top_k)
        # Draw on the (seeded) CPU generator as before, then move to the suffix's device.
        sub_positions = torch.randint(0, num_suffix_tokens, (batch_size,)).to(ids_suffix.device)
        sub_tokens = torch.randint(0, top_k, (batch_size,)).to(ids_suffix.device)
        candidate_ids = ids_suffix.repeat(batch_size, 1)
        rows = torch.arange(batch_size, device=ids_suffix.device)
        candidate_ids[rows, sub_positions] = top_k_tokens[0, sub_positions, sub_tokens]

        # 3. Score every candidate in chunks of `search_batch_size`.
        chunk_losses = []
        with torch.no_grad():
            for start in range(0, batch_size, search_batch_size):
                ids_chunk = candidate_ids[start:start + search_batch_size]
                n_chunk = ids_chunk.shape[0]
                batch_input_ids = torch.cat(
                    [ids_before.expand(n_chunk, -1), ids_chunk, ids_after.expand(n_chunk, -1)], dim=1
                )
                hidden = model(input_ids=batch_input_ids).last_hidden_state
                chunk_losses.append(_retrieval_loss(_last_token_pool(hidden), adv_norm, benign_norm))
        candidate_losses = torch.cat(chunk_losses).float()

        best_idx = int(torch.argmin(candidate_losses))
        ids_suffix = candidate_ids[best_idx].unsqueeze(0)
        # Also record the chosen candidate, so the last step's choice is not lost.
        candidate_best = candidate_losses[best_idx].item()
        if candidate_best < best_loss:
            best_loss, ids_suffix_best = candidate_best, ids_suffix.clone()

    return best_loss, ids_suffix_best


all_result = {}
selected_index = [0, 1, 4, 5, 7, 8, 9, 10, 11, 12]
_selected_set = set(selected_index)

for position in positions:
    for length in lengths:
        for i, query in enumerate(queries):
            if i not in _selected_set:
                continue
            text_before, text_mid, text_after = _split_query(query, position, length)
            ids_before = _tokenize(text_before)
            ids_after = _tokenize(text_after)
            ids_suffix = _tokenize(text_mid)  # initial suffix that GCG edits

            # Only the first "adv" document is targeted; all "benign" documents are kept.
            adv_doc_embeds = _embed_docs(_as_doc_list(adv_docs[i], first_only=True))
            benign_doc_embeds = _embed_docs(_as_doc_list(benign_docs[i]))

            best_loss, ids_suffix_best = _gcg_optimize(
                ids_before, ids_suffix, ids_after, adv_doc_embeds, benign_doc_embeds
            )
            all_result[(i, position, length)] = {
                'best_loss': best_loss,
                'ids_suffix_best': ids_suffix_best,
            }

Optimization steps: 100%|██████████| 100/100 [00:18<00:00,  5.37step/s]
Optimization steps: 100%|██████████| 100/100 [00:18<00:00,  5.37step/s]
Optimization steps: 100%|██████████| 100/100 [00:19<00:00,  5.18step/s]
Optimization steps: 100%|██████████| 100/100 [00:19<00:00,  5.18step/s]
Optimization steps: 100%|██████████| 100/100 [00:20<00:00,  4.94step/s]
Optimization steps:   0%|          | 0/100 [00:00<?, ?step/s]
Optimization steps: 100%|██████████| 100/100 [00:17<00:00,  5.65step/s]

Optimization steps: 100%|██████████| 100/100 [00:17<00:00,  5.65step/s]

Optimization steps: 100%|██████████| 100/100 [00:17<00:00,  5.60step/s]

Optimization steps: 100%|██████████| 100/100 [00:19<00:00,  5.02step/s]
Optimization steps: 100%|██████████| 100/100 [00:19<00:00,  5.02step/s]
Optimization steps: 100%|██████████| 100/100 [00:19<00:00,  5.17step/s]
Optimization steps: 100%|██████████| 100/100 [00:19<00:00,  5.17step/s]
Optimization steps: 100%|██████████| 100/100 [00:18<00:00,  5.32step/s]

In [14]:
# store all_result into json
# JSON object keys must be strings and tensors are not serializable, so tuple keys are
# stringified and tensor values become (nested) lists. The conversion finishes before the
# file is opened, so a failure cannot leave a truncated results file behind. `all_result`
# is rebound to the serializable form (as before); re-running is safe because str() of an
# already-string key is unchanged.
import json

all_result = {
    str(key): {
        field: value.cpu().tolist() if torch.is_tensor(value) else value
        for field, value in fields.items()
    }
    for key, fields in all_result.items()
}
with open('gcg_all_result_begin.json', 'w') as f:
    json.dump(all_result, f)