In [1]:
import torch
import torch.nn as nn
import numpy as np
import time
import pickle
from custom_model import CustomModel
from opt_utils import get_nonascii_toks, token_gradients, sample_control, get_filtered_cands, get_logits, load_model_and_tokenizer, get_prompt, get_primary_activation
from multi_prompt_single_model_all_layers_llama.suffix_manager import SuffixManager
import random
import os
import gc

# Show full arrays/tensors on one wide line in fixed-point notation (no scientific notation, no truncation).
np.set_printoptions(linewidth=10000, suppress=True)
torch.set_printoptions(threshold=float('inf'), linewidth=100000, sci_mode=False)



In [2]:
model_name = "llama3_8b"  # directory name used under ../loaded_models and the probe/activation stores

In [3]:
model_path = '../loaded_models/{}'.format(model_name)  # local checkpoint directory for `model_name`

In [4]:
# Initial adversarial suffix string that the optimisation loop mutates.
adv_string_init = "!_!_!_!_!_!_!_!_!_!_!_!_!_!_!_!"

total_prompts = 31134  # size of the prompt pool that get_prompt() indexes into (presumably; verify)
num_prompts = 100      # how many prompts the attack may eventually be optimised against
seed = 0               # fixes the chosen prompt subset and downstream candidate sampling

# Seed every RNG the run may touch so that re-running the notebook reproduces the same prompts
# and, as far as the helpers use these RNGs, the same candidate suffixes.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

prompt_indices = random.sample(range(total_prompts), num_prompts)

num_steps = 2000
topk = 256
batch_size = 512
allow_non_ascii = False  # you can set this to True to use unicode tokens

In [5]:
print(str(prompt_indices))  # log the sampled prompt indices so the run can be traced back to its prompts

[3816, 4415, 21424, 767, 24817, 17446, 14778, 15796, 16509, 22776, 2561, 30960, 2990, 20188, 16356, 11379, 23249, 10231, 20290, 27294, 13642, 5826, 2458, 24193, 22437, 9452, 7315, 19042, 8021, 19568, 14487, 1965, 7140, 13599, 17203, 14006, 6426, 5004, 10082, 2513, 15553, 13152, 22560, 5572, 26898, 3164, 20823, 3358, 17652, 8739, 2713, 25588, 21165, 26688, 19753, 30567, 4023, 30460, 19345, 8204, 29591, 27478, 4550, 28646, 20722, 24393, 1253, 8027, 10456, 14432, 29927, 15310, 26769, 23599, 2830, 26353, 6779, 6637, 25412, 12443, 18757, 26191, 27499, 4556, 20395, 258, 25944, 21137, 3424, 1405, 16937, 26009, 21891, 12073, 12981, 1008, 27350, 10845, 5124, 1778]


In [6]:
if torch.cuda.is_available():  # release cached GPU blocks before the models are loaded
    torch.cuda.empty_cache()

In [7]:
# Start the attack with the text of the first sampled prompt (from the test dataset);
# add_next_prompt() appends further prompts as the optimisation progresses.
first_prompt_index = prompt_indices[0]
texts = [get_prompt(first_prompt_index)]

In [8]:
# Two copies of the network:
#   * `model` (from load_model_and_tokenizer) is wrapped by CustomModel and drives the gradient search;
#   * `model32` is a CPU copy loaded with from_pretrained defaults, used by check_misclassification.
model, tokenizer = load_model_and_tokenizer(model_path)
print(model.dtype)

from transformers import AutoModelForCausalLM
model32 = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map='cpu',
    local_files_only=True,
)
model32.eval()  # inference mode; eval() returns the module itself, so this matches chaining it
print(model32.dtype)


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

torch.bfloat16


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

torch.float32


In [10]:
# Load one linear probe per layer. Every immediate subdirectory of the probe directory whose name is
# a decimal integer is treated as a layer index and must contain a `model.pickle` file.
# SECURITY: pickle.load can execute arbitrary code — only load probe files from a trusted source.
probe_root = f'../trained_linear_probes_microsoft/{model_name}'

# os.scandir raises FileNotFoundError for a missing directory instead of silently finding no layers.
with os.scandir(probe_root) as entries:
    layer_dirs = {
        int(entry.name): entry.path
        for entry in entries
        # Skip non-layer directories (e.g. `.ipynb_checkpoints`), which would otherwise make int() raise.
        if entry.is_dir() and entry.name.isdecimal()
    }

layers = sorted(layer_dirs)
if not layers:
    raise RuntimeError(f'No per-layer probe directories found under {probe_root}')

# Insertion order is ascending layer index; check_misclassification reports per-layer results in this order.
linear_models = {}
for layer in layers:
    # `with` closes each file handle (the original open() was never closed).
    with open(os.path.join(layer_dirs[layer], 'model.pickle'), 'rb') as probe_file:
        linear_models[layer] = pickle.load(probe_file)

custom_model = CustomModel(model, linear_models)
device = custom_model.base_model.get_input_embeddings().weight.device

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [11]:
def get_primary_activations(prompt_index):
    """Fetch the reference (primary) activations of one prompt for every probed layer.

    Args:
        prompt_index: index of the prompt in the prompt pool.

    Returns:
        dict mapping each layer in the global `layers` (in that order) to the tensor returned by
        `get_primary_activation(index=prompt_index, model=model_name, layer=layer)`, moved to `device`.
    """
    activations_by_layer = {}
    for layer in layers:
        activation = get_primary_activation(index=prompt_index, model=model_name, layer=layer)
        activations_by_layer[layer] = activation.to(device)
    return activations_by_layer

In [12]:
# One {layer: activation} dict per prompt, index-aligned with `texts` (add_next_prompt() extends both).
first_prompt_activations = get_primary_activations(prompt_indices[0])
primary_activations_all_prompts = [first_prompt_activations]

In [13]:
# BCE target shared by all probes: the loss is minimised, pushing every probe towards label 0.
target = torch.tensor([0], dtype=torch.float32).to(device)

In [15]:
def check_misclassification(adv_suffix, suffix_manager, primary_activations):
    """Run every linear probe on the activation shift caused by `adv_suffix`.

    The prompt built by `suffix_manager` (with `adv_suffix` inserted) is passed through the CPU copy
    of the model, `model32`. For each probed layer, the last token's hidden state minus that layer's
    entry in `primary_activations` is fed to the layer's probe.

    Args:
        adv_suffix: adversarial suffix string to evaluate.
        suffix_manager: SuffixManager for the current prompt; supplies the token ids.
        primary_activations: dict mapping layer index -> reference activation tensor for this prompt
            (may live on any device; it is moved to CPU here).

    Returns:
        (labels, probs): two lists with one entry per probe, in the iteration order of the global
        `linear_models` — the probe's predicted label and its `predict_proba` row as a list.
    """
    # The check runs entirely on the CPU model, so build the inputs there directly rather than
    # moving them to the GPU and back.
    input_ids = suffix_manager.get_input_ids(adv_suffix).unsqueeze(0).to('cpu')
    attention_mask = torch.ones_like(input_ids)

    # Inference only: without no_grad the forward pass records an autograd graph through the whole
    # model, keeping intermediate activations alive in host memory.
    with torch.no_grad():
        outputs = model32(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
    hidden_states = outputs['hidden_states']

    labels = []
    probs = []

    for num_layer, linear_model in linear_models.items():
        # Offset by one: presumably hidden_states[0] is the embedding output (Hugging Face convention),
        # so layer k is at index k + 1 — confirm against how the probes/primary activations were built.
        poisoned_activation = hidden_states[num_layer + 1][0, -1].float()
        delta = (poisoned_activation - primary_activations[num_layer].to('cpu')).detach().cpu().float().numpy().reshape(1, -1)

        label = linear_model.predict(delta)
        prob = linear_model.predict_proba(delta)

        labels.append(label[0].tolist())
        probs.append(prob[0].tolist())

    return labels, probs


In [17]:
def format_probs(probs):
    """Format per-layer probability rows as a compact string with 8 decimal places.

    Args:
        probs: iterable of per-layer probability sequences, e.g. `[[p0, p1], ...]` as produced by
            `predict_proba(...)[0].tolist()`. Rows of any length are supported.

    Returns:
        A string such as "[[0.10000000, 0.90000000], [0.50000000, 0.50000000]]"; "[]" for no rows.
    """
    formatted_rows = (
        "[" + ", ".join(f"{p:.8f}" for p in row) + "]"
        for row in probs
    )
    return "[" + ", ".join(formatted_rows) + "]"

def format_losses(losses):
    """Return `losses` rendered as "[l0, l1, ...]" with 8 decimal places each ("[]" if empty)."""
    body = ", ".join(format(loss, '.8f') for loss in losses)
    return "[{}]".format(body)

In [18]:
def check_probabilities(probs_all_prompts_all_layers):
    """Return the smallest first-column probability over every prompt and layer.

    Args:
        probs_all_prompts_all_layers: list (prompt) of lists (layer) of probability rows; index 0
            of each row is presumably the probability of class 0, the attack target.

    Returns:
        The minimum of those values, or the sentinel 100.0 when there are none.
    """
    first_column = [
        layer_probs[0]
        for prompt_probs in probs_all_prompts_all_layers
        for layer_probs in prompt_probs
    ]
    # Prepending the sentinel mirrors a running min() that starts from 100.0.
    return min([100.0] + first_column)


def percentage_of_successful_prompts_all_layers(probs_all_prompts_all_layers, confidence_threshold=.8):
    """Fraction (0..1, despite the name) of prompts on which the attack succeeds at every layer.

    A prompt counts as successful when, for every probed layer, the first-column probability
    (presumably class 0, the attack target) is >= `confidence_threshold`.

    Args:
        probs_all_prompts_all_layers: list (prompt) of lists (layer) of probability rows.
        confidence_threshold: minimum first-column probability required at each layer.

    Returns:
        successful prompts / total prompts, or 0.0 when no prompts are given
        (instead of raising ZeroDivisionError).
    """
    num_prompts_checked = len(probs_all_prompts_all_layers)
    if num_prompts_checked == 0:
        return 0.0

    num_successful = sum(
        all(layer_probs[0] >= confidence_threshold for layer_probs in prompt_probs)
        for prompt_probs in probs_all_prompts_all_layers
    )
    return num_successful / num_prompts_checked


def add_next_prompt():
    """Add the next prompt from `prompt_indices` to the working set.

    Appends its text to `texts` and its primary activations to `primary_activations_all_prompts`.
    The optimisation loop zips these two lists, so they must stay index-aligned. The next index is
    therefore derived once, and nothing is appended until both values have been fetched. If either
    fetch raises, both lists are left untouched.

    Raises:
        RuntimeError: if `texts` and `primary_activations_all_prompts` are already out of sync.
        IndexError: if every prompt in `prompt_indices` is already in use.
    """
    next_position = len(texts)
    if next_position != len(primary_activations_all_prompts):
        raise RuntimeError(
            f'texts ({len(texts)}) and primary_activations_all_prompts '
            f'({len(primary_activations_all_prompts)}) are out of sync'
        )

    prompt_index = prompt_indices[next_position]
    next_text = get_prompt(prompt_index)
    next_primary_activations = get_primary_activations(prompt_index)

    texts.append(next_text)
    primary_activations_all_prompts.append(next_primary_activations)


In [None]:
# Optimise one adversarial suffix shared by every prompt in `texts`. Each step:
#   1-2. sum the per-prompt coordinate gradients of the suffix and log the current probe
#        losses / labels / probabilities for every prompt;
#   3.   sample candidate suffixes from the summed gradient and keep the one whose BCE loss,
#        summed over all probe layers and all prompts, is lowest;
#   4.   grow the prompt set when the suffix succeeds on most prompts or has stalled.
suffix_manager = SuffixManager(tokenizer, texts[0], adv_string_init)

not_allowed_tokens = None if allow_non_ascii else get_nonascii_toks(tokenizer)

adv_suffix = adv_string_init

start_time = time.time()

# Step at which the most recent prompt was added; drives the patience schedule in step 4.
last_added = 0

for i in range(0, num_steps + 1):

    gc.collect()

    # Per-step accumulators, one entry per prompt in `texts` order.
    coordinate_grad_all_prompts = None
    losses_all_prompts_all_layers = []
    labels_all_prompts_all_layers = []  # list (prompt) of lists (layer) of labels
    probs_all_prompts_all_layers = []   # list (prompt) of lists (layer) of probability rows

    for text, primary_activations_current_prompt in zip(texts, primary_activations_all_prompts):

        # Step 1. Encode user prompt (behavior + adv suffix) as tokens and return token ids.
        suffix_manager = SuffixManager(tokenizer, text, adv_string_init)

        input_ids = suffix_manager.get_input_ids(adv_suffix)
        input_ids = input_ids.to(device)

        # Step 2. Compute this prompt's coordinate gradient and sum it over prompts.
        coordinate_grad_current_prompt, losses_current_prompt_all_layers = token_gradients(
            custom_model, input_ids, suffix_manager.adv_string_slice,
            target, primary_activations=primary_activations_current_prompt
        )

        losses_all_prompts_all_layers.append(losses_current_prompt_all_layers)

        if coordinate_grad_all_prompts is None:
            coordinate_grad_all_prompts = coordinate_grad_current_prompt
        else:
            coordinate_grad_all_prompts += coordinate_grad_current_prompt

        labels_current_prompt_all_layers, probs_current_prompt_all_layers = check_misclassification(
            adv_suffix, suffix_manager, primary_activations_current_prompt
        )
        labels_all_prompts_all_layers.append(labels_current_prompt_all_layers)
        probs_all_prompts_all_layers.append(probs_current_prompt_all_layers)

        gc.collect()
        torch.cuda.empty_cache()

    print(f"i: {i}")

    for prompt_losses, prompt_labels, prompt_probs in zip(
        losses_all_prompts_all_layers, labels_all_prompts_all_layers, probs_all_prompts_all_layers
    ):
        print(f"losses: {format_losses(prompt_losses)} labels: {prompt_labels} probs: {format_probs(prompt_probs)}")
    print(f"{adv_suffix}")
    print("-------------------------------------------------------\n")

    # The final iteration only evaluates/logs the last suffix.
    if i == num_steps:
        break

    # Step 3. Sample a batch of new tokens based on the coordinate gradient.
    # Notice that we only need the one that minimizes the loss.
    with torch.no_grad():

        # Step 3.1 Slice the input to locate the adversarial suffix. `input_ids` and `suffix_manager`
        # come from the last prompt above; the suffix tokens are presumably the same for every prompt.
        adv_suffix_tokens = input_ids[suffix_manager.adv_string_slice].to(device)

        # Step 3.2 Select all the possible suffixes after replacements
        # Many suffixes — sometimes all — are filtered in the function get_filtered_cands due to length mismatch
        # after decoding and re-encoding. Therefore, we replace each position with topk tokens in sample_control
        # and randomly select batch_size after filtering

        # Encoded suffixes
        new_adv_suffix_tokens = sample_control(
            adv_suffix_tokens,
            coordinate_grad_all_prompts,
            topk=topk,
            not_allowed_tokens=not_allowed_tokens
        )

        # Step 3.3 This step ensures all adversarial candidates have the same number of tokens.
        # Decoded suffixes
        new_adv_suffix = get_filtered_cands(
            tokenizer,
            new_adv_suffix_tokens,
            new_batch_size=batch_size,
            filter_cand=True,
            curr_control=adv_suffix
        )

        # Step 3.4 Score every candidate: BCE loss summed over probe layers, then over prompts.
        candidate_losses = None

        for text, primary_activations_current_prompt in zip(texts, primary_activations_all_prompts):

            suffix_manager = SuffixManager(tokenizer, text, adv_string_init)
            input_ids = suffix_manager.get_input_ids(adv_suffix)
            input_ids = input_ids.to(device)

            logits_per_classifier = get_logits(
                custom_model=custom_model,
                tokenizer=tokenizer,
                input_ids=input_ids,
                control_slice=suffix_manager.adv_string_slice,
                primary_activations=primary_activations_current_prompt,
                test_controls=new_adv_suffix,
                batch_size=8 # decrease this number if you run into OOM.
            )

            # Match the target to the logits' dtype/device/shape locally. The global `target` is left
            # untouched so token_gradients keeps receiving the original float32 target on later steps.
            last_classifier_logits = next(reversed(logits_per_classifier.values()))
            expanded_target = target.to(
                device=last_classifier_logits.device, dtype=last_classifier_logits.dtype
            ).expand_as(last_classifier_logits)

            candidate_losses_current_prompt = None

            for logits_current_layer in logits_per_classifier.values():
                losses_current_layer = nn.BCEWithLogitsLoss(reduction='none')(logits_current_layer, expanded_target)

                if candidate_losses_current_prompt is None:
                    candidate_losses_current_prompt = losses_current_layer
                else:
                    candidate_losses_current_prompt += losses_current_layer

            if candidate_losses is None:
                candidate_losses = candidate_losses_current_prompt
            else:
                candidate_losses += candidate_losses_current_prompt
            gc.collect()

        best_new_adv_suffix_id = candidate_losses.argmin()
        best_new_adv_suffix = new_adv_suffix[best_new_adv_suffix_id]

        adv_suffix = best_new_adv_suffix

    # Step 4. Grow the prompt set. Note: `probs_all_prompts_all_layers` was measured on the suffix
    # from *before* this step's update.
    if len(texts) < len(prompt_indices):

        if percentage_of_successful_prompts_all_layers(probs_all_prompts_all_layers, confidence_threshold=.6) >= .8:
            # If the attack is successful on 80% or more prompts, add next prompt
            add_next_prompt()
            last_added = i

        # Patience schedule: if no prompt was added for a while, some prompt is probably resisting the
        # suffix, and adding a new prompt can help. The more prompts, the shorter the patience.
        if len(texts) >= 5 and i - last_added == 10:
            add_next_prompt()
            last_added = i
        elif len(texts) >= 3 and i - last_added == 20:
            add_next_prompt()
            last_added = i
        elif len(texts) == 2 and i - last_added == 30:
            add_next_prompt()
            last_added = i

end_time = time.time()

print(f"Total time: {end_time - start_time} seconds")