In [1]:
from copy import deepcopy
import colorama
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed, AutoModel
from transformers.cache_utils import DynamicCache
from datasets import load_dataset
import huggingface_hub

# Reproducibility: one seed shared by NumPy, PyTorch (CPU, plus CUDA when a GPU
# is visible) and the transformers `set_seed` helper.
seed = 0

for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
set_seed(seed)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Model to attack. Set the GCG_MODEL_NAME environment variable (hub id or local
# path) to override the machine-specific default below without editing the notebook.
import os

# model_name = "meta-llama/Llama-3.2-1B-Instruct" # Tough cookie! Also, requires permissions through HF authentication
# model_name = "Qwen/Qwen3-0.6B"
model_name = os.environ.get(
    "GCG_MODEL_NAME",
    "/u/anp407/Workspace/Huggingface/Qwen/Qwen3-Embedding-0.6B",
)

# 4-bit NF4 quantization via bitsandbytes, computing in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Attack parameters
batch_size = 512  # Number of candidate suffixes evaluated per step (512 in GCG paper)
search_batch_size = 256  # Number of candidates run through the model in one forward pass
top_k = 256  # Number of top tokens to sample from (256 in GCG paper)
steps = 100  # Total number of optimization steps (500 in GCG paper)
suffix_initial_token = " !"  # Initial token repeated for the length of the suffix
device_name = "cuda:0"

# Assertions (positivity first, so the modulo check cannot divide by zero)
assert batch_size > 0 and search_batch_size > 0, "Batch sizes must be positive"
assert batch_size % search_batch_size == 0, "Batch size must be divisible by search batch size (convenience)"

In [3]:
# Load the tokenizer and the embedding model. The model is quantized according to
# `quantization_config`, placed on `device_name`, and switched to inference mode.
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModel.from_pretrained(
    model_name,
    device_map=device_name,
    quantization_config=quantization_config,
    trust_remote_code=True,
)
model.eval()

In [4]:
# Loading dataset
# dataset = load_dataset("walledai/AdvBench", split='train')
import json

# The file is a JSON object whose values are per-query records. Read it inside a
# context manager so the handle is closed, and decode it explicitly as UTF-8 (JSON's encoding).
with open("nq_gemini_enrich_10_gemma_verified_final.json", encoding="utf-8") as dataset_file:
    dataset = json.load(dataset_file)
records = list(dataset.values())

# Getting request and target. Position i in these lists is the query index `i`
# used as part of each result key in the optimization cell.
queries = [record['question'] for record in records]

correct_answers = [record['correct answer'] for record in records]
incorrect_answers = [record['incorrect answer'] for record in records]
# Passages per query: the filtered list followed by its "extra" counterpart.
adv_docs = [record['filtered_adv_responses'] + record['extra_filtered_adv_responses'] for record in records]
benign_docs = [record['filtered_benign_responses'] + record['extra_filtered_benign_responses'] for record in records]

In [5]:
# Sweep configuration for the optimization cell.
positions = ["end"]                    # insertion point of the adversarial tokens: "start", "mid" or "end"
lengths = [2 ** e for e in range(5)]   # suffix repetitions to try: 1, 2, 4, 8, 16
loss_type = 0                          # gradient objective: 0, 1 or 2

In [None]:
# read existing results if any
# Each line of the JSONL file is {"(i, position, length, loss_type)": result_dict};
# the key was written as str() of a tuple and is parsed back into that tuple here.
import ast
import json
import os

RESULTS_PATH = 'gcg_all_result.jsonl'


def load_gcg_results(path: str) -> dict:
    """Load previously saved GCG results into a dict keyed by tuple.

    Args:
        path: JSONL results file; a missing file yields an empty dict.

    Returns:
        Mapping from the parsed key tuple to its result dict. If a key appears on
        several lines, the last line wins.

    Raises:
        ValueError: a key is not a plain Python literal (e.g. injected code).
        json.JSONDecodeError: a non-blank line is not valid JSON.
    """
    results = {}
    if not os.path.exists(path):
        return results
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            entry = json.loads(line)
            for key_str, value in entry.items():
                # literal_eval only accepts Python literals, so a tampered file
                # cannot execute code the way eval() could.
                results[ast.literal_eval(key_str)] = value
    return results


all_result = load_gcg_results(RESULTS_PATH)

In [None]:
# Running optimization with GCG (Greedy Coordinate Gradient): one run per
# (position, length, query) combination, recording per-step history and the best suffix.
# early_stop_thred = 0.1
# tolerance_step = 30
succeed_step = -1
from torch import Tensor


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Pool each sequence to the hidden state of its last real (non-padding) token.

    Args:
        last_hidden_states: (batch, seq_len, hidden) final-layer hidden states.
        attention_mask: (batch, seq_len) mask with 1/True for real tokens. It should
            span the same seq_len as ``last_hidden_states``.

    Returns:
        A (batch, hidden) tensor.
    """
    # If every row ends in a real token, the batch is left-padded or unpadded, so
    # the final position is each row's last real token.
    if attention_mask.shape[1] > 0 and bool(attention_mask[:, -1].all()):
        return last_hidden_states[:, -1]
    # Right padding: pick the index of the last real token in each row.
    # (An empty mask gives index -1, i.e. the final position.)
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

import json
import torch.nn.functional as F


def _split_query(query, position):
    """Return (text_before, text_after) surrounding the adversarial tokens for `position`."""
    if position == "start":
        return "", query
    if position == "mid":
        half = len(query) // 2
        return query[:half], query[half:]
    if position == "end":
        return query, ""
    raise ValueError(f"Unknown position {position!r}; expected 'start', 'mid' or 'end'")


def _as_doc_list(docs, first_only):
    """Normalize docs (list or str) to a list; keep only the first doc when `first_only`."""
    if isinstance(docs, list) and len(docs) > 0:
        return [docs[0]] if first_only else docs
    return [docs] if isinstance(docs, str) else docs


def _embed_docs(docs):
    """Embed `docs` without gradients, one last-token-pooled row per doc."""
    enc = tokenizer(docs, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        # Use the tokenizer's own mask. Comparing ids to pad_token_id would also hide
        # any real token that shares the pad id.
        out = model(input_ids=enc['input_ids'], attention_mask=enc['attention_mask'])
    return last_token_pool(out.last_hidden_state, enc['attention_mask'])


def _gradient_loss(q_embeds, adv_doc_embeds, benign_doc_embeds, loss_type):
    """Differentiable attack loss used for the gradient step (lower = better for the attacker)."""
    if loss_type == 0:
        # plan 0: simple similarity difference
        sim_adv = torch.cosine_similarity(q_embeds, adv_doc_embeds, dim=-1).min()  # more the better
        sim_benign = torch.cosine_similarity(q_embeds, benign_doc_embeds, dim=-1).max()  # less the better
        return -sim_adv + sim_benign
    if loss_type == 1:
        # plan 1: contrastive loss https://lilianweng.github.io/posts/2021-05-31-contrastive/
        temp = 0.07  # common temperature https://openreview.net/forum?id=ejHUr4nfHhD
        sim_adv = torch.cosine_similarity(q_embeds, adv_doc_embeds, dim=-1)
        sim_benign = torch.cosine_similarity(q_embeds, benign_doc_embeds, dim=-1)
        logits = torch.cat([sim_adv, sim_benign], dim=0) / temp
        # Soft target: probability 1 on index 0 (the first adversarial doc), 0 elsewhere.
        labels = torch.zeros(logits.size(0), dtype=torch.float, device=logits.device)
        labels[0] += 1
        return torch.nn.CrossEntropyLoss()(logits, labels)
    if loss_type == 2:
        # plan 2: qwen training loss https://arxiv.org/pdf/2506.05176.pdf
        temp = 0.07
        sim_adv = torch.cosine_similarity(q_embeds, adv_doc_embeds, dim=-1) / temp
        sim_benign = torch.cosine_similarity(q_embeds, benign_doc_embeds, dim=-1) / temp
        sim_adv_benign = torch.cosine_similarity(adv_doc_embeds, benign_doc_embeds, dim=-1) / temp  # independent of the suffix
        sim_benign_exp_mean = torch.exp(sim_benign).mean(dim=0)  # less the better
        sim_adv_benign_exp_mean = torch.exp(sim_adv_benign).mean(dim=0)
        sim_adv_exp_mean = torch.exp(sim_adv).mean(dim=0)  # more the better
        Z = sim_adv_exp_mean + sim_benign_exp_mean + sim_adv_benign_exp_mean
        return -torch.log(sim_adv_exp_mean / Z)
    # Other ideas sketched earlier but not implemented: triplet loss (margin 0.1),
    # NCE over [adv; benign] logits, MarginRankingLoss(sim_adv, sim_benign).
    raise ValueError(f"Unknown loss_type {loss_type!r}; expected 0, 1 or 2")


def _run_gcg(query, raw_adv_docs, raw_benign_docs, position, length, loss_type):
    """Optimize `length` adversarial suffix repetitions inserted into `query` at `position`.

    Returns a JSON-serializable dict with 'loss_history', 'ids_suffix_history',
    'succeed_step' (first step whose best candidate scored < 0, else -1),
    'best_loss' and 'ids_suffix_best'.
    """
    text_before, text_after = _split_query(query, position)
    text_mid = suffix_initial_token * length

    def encode(text):
        return tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)['input_ids']

    ids_before, ids_suffix, ids_after = encode(text_before), encode(text_mid), encode(text_after)
    # The tokenizer need not map the repeated token to exactly `length` tokens, so
    # sample positions from the actual suffix width.
    n_suffix = ids_suffix.shape[1]

    # Only the first adversarial doc is targeted; every benign doc is a competitor.
    adv_doc_embeds = _embed_docs(_as_doc_list(raw_adv_docs, first_only=True))
    benign_doc_embeds = _embed_docs(_as_doc_list(raw_benign_docs, first_only=False))
    adv_norm = F.normalize(adv_doc_embeds, p=2, dim=1)
    benign_norm = F.normalize(benign_doc_embeds, p=2, dim=1)

    embed_layer = model.get_input_embeddings()
    embed_weights = embed_layer.weight
    with torch.no_grad():
        # The context around the suffix never changes and needs no gradient.
        embeds_before = embed_layer(ids_before)
        embeds_after = embed_layer(ids_after)

    loss_history, ids_suffix_history = [], []
    succeed_step = -1
    best_loss = float("inf")
    ids_suffix_best = ids_suffix.clone()

    for step in tqdm(range(steps), desc="Optimization steps", unit="step"):
        # 1. Gradient of the loss w.r.t. a one-hot encoding of the current suffix.
        one_hot = F.one_hot(ids_suffix, num_classes=embed_weights.shape[0]).to(model.device, embed_weights.dtype)
        one_hot.requires_grad_(True)
        full_embeds = torch.cat([embeds_before, one_hot @ embed_weights, embeds_after], dim=1)
        full_mask = torch.ones(full_embeds.shape[:2], dtype=torch.long, device=full_embeds.device)
        outputs = model(inputs_embeds=full_embeds, attention_mask=full_mask)
        # Pool over the FULL sequence (query + suffix). A query-only mask would pool a
        # position before the suffix, which (assuming causal attention) ignores the suffix.
        q_embeds = last_token_pool(outputs.last_hidden_state, full_mask)
        loss = _gradient_loss(q_embeds, adv_doc_embeds, benign_doc_embeds, loss_type)
        # autograd.grad differentiates w.r.t. one_hot only, so model weights do not
        # accumulate .grad across steps.
        (one_hot_grad,) = torch.autograd.grad(loss, [one_hot])
        gradients = -one_hot_grad

        loss_value = loss.item()
        if loss_value < best_loss:
            best_loss = loss_value
            ids_suffix_best = ids_suffix.clone()

        # 2. Candidates: each swaps one random suffix position for one of its top-k tokens.
        top_k_tokens = torch.topk(gradients, top_k, dim=-1).indices  # (1, n_suffix, top_k)
        sub_positions = torch.randint(0, n_suffix, (batch_size,)).to(model.device)
        sub_tokens = torch.randint(0, top_k, (batch_size,)).to(model.device)
        candidate_ids = ids_suffix.repeat(batch_size, 1)
        rows = torch.arange(batch_size, device=model.device)
        candidate_ids[rows, sub_positions] = top_k_tokens[0, sub_positions, sub_tokens]

        # 3. Score candidates exactly, search_batch_size at a time.
        # NOTE(review): candidates are always ranked with the loss_type 0 objective,
        # even when the gradient above uses loss_type 1 or 2.
        candidate_losses = []
        for slice_start in range(0, batch_size, search_batch_size):
            ids_slice = candidate_ids[slice_start:slice_start + search_batch_size]
            n_rows = ids_slice.shape[0]
            batch_input_ids = torch.cat([ids_before.repeat(n_rows, 1), ids_slice, ids_after.repeat(n_rows, 1)], dim=1)
            # All rows share one length with no padding, so the mask is all ones.
            batch_mask = torch.ones_like(batch_input_ids)
            with torch.no_grad():
                batch_outputs = model(input_ids=batch_input_ids, attention_mask=batch_mask)
                batch_q_norm = F.normalize(last_token_pool(batch_outputs.last_hidden_state, batch_mask), p=2, dim=1)
                batch_loss = (-(batch_q_norm @ adv_norm.t()).min(dim=1).values
                              + (batch_q_norm @ benign_norm.t()).max(dim=1).values)
            candidate_losses.extend(batch_loss.tolist())

        best_idx = int(np.argmin(candidate_losses))
        ids_suffix = candidate_ids[best_idx].unsqueeze(0)
        min_loss = float(candidate_losses[best_idx])
        # Success: the best candidate is closer to the adversarial doc than to every benign doc.
        if min_loss < 0 and succeed_step == -1:
            succeed_step = step  # first step at which the attack succeeded

        loss_history.append(min_loss)
        ids_suffix_history.append(ids_suffix.cpu().tolist())

        del one_hot, full_embeds, outputs, q_embeds, loss, gradients, candidate_ids, candidate_losses
        torch.cuda.empty_cache()

    return {
        'loss_history': loss_history,
        'ids_suffix_history': ids_suffix_history,
        'succeed_step': succeed_step,
        'best_loss': best_loss,
        'ids_suffix_best': ids_suffix_best.cpu().tolist(),
    }


for position in positions:
    for length in lengths:
        for i, query in enumerate(queries):
            key = (i, position, length, loss_type)
            if key in all_result:
                print(f"Skipping already processed (i={i}, position={position}, length={length}, loss_type={loss_type})")
                continue
            # Build the result locally and register it only after it has been saved. A run
            # interrupted midway is therefore redone on re-execution rather than skipped.
            run_result = _run_gcg(query, adv_docs[i], benign_docs[i], position, length, loss_type)
            with open('gcg_all_result.jsonl', 'a', encoding='utf-8') as f:
                f.write(json.dumps({str(key): run_result}) + '\n')
            all_result[key] = run_result

Optimization steps:  14%|█▍        | 14/100 [00:02<00:16,  5.30step/s]