In [1]:
import contextlib
import gc
import os
import pickle
import random
import sys
import time

import numpy as np
import torch
import torch.nn as nn

from custom_model import CustomModel
from opt_utils import get_nonascii_toks, token_gradients, sample_control, get_filtered_cands, get_logits, load_model_and_tokenizer, get_prompt, get_primary_activation
from suffix_manager import SuffixManager

# Display settings only: print arrays/tensors on very wide lines, without
# scientific notation, and never summarise tensor contents with "...".
np.set_printoptions(suppress=True, linewidth=10000)
torch.set_printoptions(sci_mode=False, linewidth=100000, threshold=float('inf'))



In [2]:
# Short model identifier; interpolated into the checkpoint path and the
# linear-probe directory, and passed to get_primary_activation.
model_name = 'phi3'

In [3]:
# Local directory handed to load_model_and_tokenizer.
model_path = f'./loaded_models/{model_name}'

In [4]:
# Seed for prompt selection, so that "Restart Kernel -> Run All" attacks the
# same prompts on every run.
SEED = 0

# Starting adversarial suffix: a run of "!" placeholders separated by spaces.
adv_string_init = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"

# Size of the index range prompts are drawn from (indices 0 .. total_prompts - 1).
total_prompts = 31134

# 100 distinct prompt indices, drawn without replacement from a dedicated
# seeded generator so the global `random` state is left untouched.
prompt_indices = random.Random(SEED).sample(range(total_prompts), 100)

num_steps = 2000  # number of suffix updates performed by the main loop
topk = 64  # forwarded to sample_control as `topk`
batch_size = 64  # number of candidate suffixes requested from sample_control
allow_non_ascii = False  # you can set this to True to use unicode tokens

In [5]:
# Record which prompt indices were sampled for this run.
print(prompt_indices)

[24430, 26337, 17695, 28410, 6119, 28484, 2677, 29650, 13369, 11212, 18774, 11758, 17262, 4289, 21445, 21371, 752, 6781, 30093, 28516, 6868, 2683, 9393, 10136, 12141, 1787, 1449, 13174, 15420, 19710, 5774, 24411, 15864, 19258, 18733, 26588, 13916, 27059, 18844, 15468, 5298, 9847, 29249, 15865, 28088, 9333, 3780, 16041, 6620, 12880, 18922, 27478, 10249, 29723, 11775, 7246, 14573, 31012, 22969, 21985, 27203, 30011, 3797, 24333, 10064, 19487, 4184, 30534, 26242, 16290, 105, 3047, 21905, 17191, 7586, 27966, 17388, 21644, 22750, 21539, 19889, 22339, 9807, 26909, 30507, 9645, 3860, 16054, 21778, 3305, 27006, 8557, 6368, 16422, 3118, 14525, 13332, 21750, 6984, 4086]


In [6]:
# Release cached, currently unused GPU memory before the model is loaded.
torch.cuda.empty_cache()

In [7]:
# Prompt texts currently under attack. Starts with the first sampled prompt;
# add_next_prompt() appends further ones while the main loop runs.
texts = [get_prompt(prompt_indices[0])]

In [8]:
# Load the model and its tokenizer from the local checkpoint directory.
model, tokenizer = load_model_and_tokenizer(model_path)

# Show the parameter dtype the model was loaded with.
print(model.dtype)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

torch.bfloat16


In [9]:
# Directory with one sub-directory per probed layer, each holding a pickled
# classifier in `model.pickle`.
probes_dir = f'./trained_linear_probes_microsoft/{model_name}'

# Only the immediate sub-directories matter; fall back to "no layers" when the
# directory does not exist (os.walk then yields nothing).
_, layer_dir_names, _ = next(os.walk(probes_dir), (probes_dir, [], []))

# Map layer number -> directory name. Non-numeric folders (for example
# `.ipynb_checkpoints`) are skipped instead of crashing int().
layer_dirs = {int(dir_name): dir_name for dir_name in layer_dir_names if dir_name.isdecimal()}

# Probed layer numbers in ascending order.
layers = sorted(layer_dirs)

linear_models = {}

for layer in layers:
    # SECURITY: unpickling executes arbitrary code contained in the file.
    # Only load probe files that come from a trusted source.
    with open(f'{probes_dir}/{layer_dirs[layer]}/model.pickle', 'rb') as probe_file:
        linear_models[layer] = pickle.load(probe_file)

custom_model = CustomModel(model, linear_models)

# Device of the base model's input embeddings; tensors built in this notebook
# are moved here.
device = custom_model.base_model.get_input_embeddings().weight.device

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [10]:
def get_primary_activations(prompt_index):
    """Collect the primary activation of one prompt for every probed layer.

    Args:
        prompt_index: Value forwarded to ``get_primary_activation`` as ``index``.

    Returns:
        dict keyed by each entry of the global ``layers``; every value is the
        result of ``get_primary_activation`` for that layer, moved to the
        global ``device``.
    """
    activations_by_layer = {}

    for layer in layers:
        activation = get_primary_activation(index=prompt_index, model=model_name, layer=layer)
        activations_by_layer[layer] = activation.to(device)

    return activations_by_layer

In [11]:
# One dict (layer -> primary activation) per active prompt, kept in the same
# order as `texts`; add_next_prompt() appends to both lists.
primary_activations_all_prompts = [get_primary_activations(prompt_indices[0])]

In [12]:
# Label every probe is pushed towards: a float32 tensor holding a single 0,
# created directly on the model's device.
target = torch.tensor([0], dtype=torch.float32, device=device)

In [13]:
def check_misclassification(adv_suffix, suffix_manager, primary_activations):
    """Run every linear probe on the prompt with ``adv_suffix`` attached.

    The base model is run once on the ids returned by
    ``suffix_manager.get_input_ids(adv_suffix)``. For each probe, the hidden
    state at index ``layer + 1`` of the last token, minus the prompt's primary
    activation for that layer, is classified.

    Args:
        adv_suffix: Suffix handed to ``suffix_manager.get_input_ids``.
        suffix_manager: Object providing ``get_input_ids`` for the current prompt.
        primary_activations: Mapping layer -> primary activation of the prompt.

    Returns:
        ``(labels, probs)``: two lists in the iteration order of the global
        ``linear_models``. ``labels[k]`` is the label predicted by the k-th
        probe, ``probs[k]`` its list of class probabilities.
    """
    device = custom_model.base_model.get_input_embeddings().weight.device
    input_ids = suffix_manager.get_input_ids(adv_suffix).unsqueeze(0).to(device)

    # ones_like already allocates on the device of input_ids.
    attention_mask = torch.ones_like(input_ids)

    # Evaluation only: without no_grad this forward pass would keep an
    # autograd graph alive for every returned hidden state.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

    labels = []
    probs = []

    for num_layer, linear_model in linear_models.items():
        # NOTE(review): the +1 presumably skips the embedding output stored at
        # hidden_states[0] -- confirm against how the probes were trained.
        poisoned_activation = outputs['hidden_states'][num_layer + 1][0, -1].float()

        # The probes consume a (1, n_features) float32 NumPy row.
        delta = (poisoned_activation - primary_activations[num_layer]).detach().cpu().float().numpy().reshape(1, -1)

        label = linear_model.predict(delta)
        prob = linear_model.predict_proba(delta)

        labels.append(label[0].tolist())
        probs.append(prob[0].tolist())

    return labels, probs



In [14]:
def format_probs(probs):
    """Render per-probe class probabilities as one compact string.

    Args:
        probs: Iterable of probability vectors, one per probe. Vectors may
            have any length (not only two classes).

    Returns:
        A string such as ``"[[0.25000000, 0.75000000], [0.10000000, 0.90000000]]"``
        with every value printed to 8 decimal places.
    """
    formatted_rows = [
        "[" + ", ".join(f"{p:.8f}" for p in prob_row) + "]"
        for prob_row in probs
    ]

    return "[" + ", ".join(formatted_rows) + "]"

def format_losses(losses):
    """Format loss values as ``"[l1, l2, ...]"`` with 8 decimal places each."""
    rendered = []
    for value in losses:
        rendered.append(format(value, ".8f"))
    return "[{}]".format(", ".join(rendered))

In [15]:
def check_probabilities(probs_all_prompts_all_layers):
    """Return the smallest first-entry probability over all prompts and layers.

    Args:
        probs_all_prompts_all_layers: One list per prompt, each containing one
            probability vector per layer.

    Returns:
        The minimum of element ``[0]`` over every vector, capped at 100.0
        (which is also the result when there are no vectors at all).
    """
    first_entries = [
        layer_probs[0]
        for prompt_probs in probs_all_prompts_all_layers
        for layer_probs in prompt_probs
    ]
    return min([100.0] + first_entries)


def percentage_of_successful_prompts_all_layers(probs_all_prompts_all_layers, confidence_threshold=.8):
    """Share of prompts on which every layer's probe clears the threshold.

    A prompt counts as successful when element ``[0]`` of each of its
    per-layer probability vectors is at least ``confidence_threshold``.

    Args:
        probs_all_prompts_all_layers: One list per prompt, each containing one
            probability vector per layer.
        confidence_threshold: Minimum first-entry probability required on
            every layer.

    Returns:
        A fraction in [0, 1] (not a percentage, despite the name). Empty input
        yields 0.0 instead of raising ZeroDivisionError.
    """
    num_prompts = len(probs_all_prompts_all_layers)
    if num_prompts == 0:
        return 0.0

    cnt = 0
    for prompt_probs in probs_all_prompts_all_layers:
        # The 100.0 sentinel keeps the original behaviour for a prompt with no
        # layer entries: its minimum is 100.0.
        weakest = min([100.0] + [layer_probs[0] for layer_probs in prompt_probs])
        if weakest >= confidence_threshold:
            cnt += 1

    return cnt / num_prompts


def add_next_prompt():
    """Activate the next sampled prompt.

    Appends the prompt's text to the global ``texts`` and its primary
    activations to the global ``primary_activations_all_prompts``, so the two
    lists stay index-aligned.

    Raises:
        IndexError: if every entry of ``prompt_indices`` is already in use.
    """
    next_position = len(texts)
    if next_position >= len(prompt_indices):
        raise IndexError("add_next_prompt: every sampled prompt is already in use")

    prompt_index = prompt_indices[next_position]

    # Fetch both pieces before touching either list. If a fetch fails, neither
    # list grows, so text and activations cannot end up misaligned.
    text = get_prompt(prompt_index)
    primary_activations = get_primary_activations(prompt_index)

    texts.append(text)
    primary_activations_all_prompts.append(primary_activations)


In [None]:

# Multi-prompt adversarial-suffix search.
# Every step:
#   1. sum the coordinate gradients of the current suffix over all active
#      prompts and run the probes on each prompt,
#   2. sample replacement candidates from the summed gradient,
#   3. keep the candidate whose loss, summed over prompts and probes, is lowest,
#   4. possibly activate another prompt.
suffix_manager = SuffixManager(tokenizer, texts[0], adv_string_init)
not_allowed_tokens = None if allow_non_ascii else get_nonascii_toks(tokenizer)
adv_suffix = adv_string_init

# Per-candidate (unreduced) loss; built once because the module holds no state.
candidate_loss_fn = nn.BCEWithLogitsLoss(reduction='none')

start_time = time.time()

# Step at which a prompt was last activated; drives the stall rules at the
# bottom of the loop.
last_added = 0

# One extra iteration so the state reached after the final update is also
# evaluated and printed (see the `i == num_steps` break below).
for i in range(num_steps + 1):

    gc.collect()

    coordinate_grad_all_prompts = None
    losses_all_prompts_all_layers = []

    input_ids = None

    # A list of lists
    labels_all_prompts_all_layers = []

    # A list of lists of lists
    probs_all_prompts_all_layers = []

    # Silence whatever the helpers print. The context managers restore
    # sys.stdout and close the devnull handle even when a helper raises, so a
    # failure can no longer leave the notebook with its output swallowed.
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):

        for text, primary_activations_current_prompt in zip(texts, primary_activations_all_prompts):

            # Step 1. Encode user prompt (behavior + adv suffix) as tokens and return token ids.
            suffix_manager = SuffixManager(tokenizer, text, adv_string_init)
            input_ids = suffix_manager.get_input_ids(adv_suffix)
            input_ids = input_ids.to(device)

            # Step 2. Compute Coordinate Gradient
            coordinate_grad_current_prompt, losses_current_prompt_all_layers = token_gradients(
                custom_model,
                input_ids,
                suffix_manager.adv_string_slice,
                target,
                primary_activations=primary_activations_current_prompt
            )

            losses_all_prompts_all_layers.append(losses_current_prompt_all_layers)

            # Accumulate the gradient over prompts (in place after the first one).
            if coordinate_grad_all_prompts is None:
                coordinate_grad_all_prompts = coordinate_grad_current_prompt
            else:
                coordinate_grad_all_prompts += coordinate_grad_current_prompt

            labels_current_prompt_all_layers, probs_current_prompt_all_layers = check_misclassification(adv_suffix, suffix_manager, primary_activations_current_prompt)
            labels_all_prompts_all_layers.append(labels_current_prompt_all_layers)
            probs_all_prompts_all_layers.append(probs_current_prompt_all_layers)

            gc.collect()

    print(f"i: {i}")

    for idx in range(len(labels_all_prompts_all_layers)):
        print(f"losses: {format_losses(losses_all_prompts_all_layers[idx])} labels: {labels_all_prompts_all_layers[idx]} probs: {format_probs(probs_all_prompts_all_layers[idx])}")
    print(adv_suffix)
    print("-------------------------------------------------------\n")

    # The last iteration only reports; no further update is made.
    if i == num_steps:
        break

    # Step 3. Sample a batch of new tokens based on the coordinate gradient.
    # Notice that we only need the one that minimizes the loss.
    with torch.no_grad():

        # Step 3.1 Slice the input to locate the adversarial suffix.
        # `input_ids` and `suffix_manager` are those of the last prompt
        # processed above.
        adv_suffix_tokens = input_ids[suffix_manager.adv_string_slice].to(device)

        # Step 3.2 Randomly sample a batch of replacements.
        # Encoded suffixes
        new_adv_suffix_tokens = sample_control(
            adv_suffix_tokens,
            coordinate_grad_all_prompts,
            batch_size,
            not_allowed_tokens=not_allowed_tokens,
            topk=topk
        )

        # Step 3.3 This step ensures all adversarial candidates have the same number of tokens.
        # Decoded suffixes
        new_adv_suffix = get_filtered_cands(
            tokenizer,
            new_adv_suffix_tokens,
            filter_cand=True,
            curr_control=adv_suffix
        )

        # From here on this name holds one summed loss per candidate
        # (it was a per-prompt list above).
        losses_all_prompts_all_layers = None

        for text, primary_activations_current_prompt in zip(texts, primary_activations_all_prompts):

            suffix_manager = SuffixManager(tokenizer, text, adv_string_init)
            input_ids = suffix_manager.get_input_ids(adv_suffix)
            input_ids = input_ids.to(device)

            # Step 3.4 Compute loss on these candidates and take the argmin.
            logits_per_classifier = get_logits(
                custom_model=custom_model,
                tokenizer=tokenizer,
                input_ids=input_ids,
                control_slice=suffix_manager.adv_string_slice,
                primary_activations=primary_activations_current_prompt,
                test_controls=new_adv_suffix,
                batch_size=8 # decrease this number if you run into OOM.
            )
            last_classifier_logits = next(reversed(logits_per_classifier.values()))

            # Cast a local copy. Rebinding the global `target` here would
            # silently change the dtype handed to token_gradients on every
            # later step.
            candidate_target = target.to(last_classifier_logits.dtype)
            expanded_target = candidate_target.expand_as(last_classifier_logits)

            losses_current_prompt_all_layers = None

            # Sum the per-candidate loss over all probes of this prompt.
            for num_layer, logits_current_layer in logits_per_classifier.items():
                losses_current_layer = candidate_loss_fn(logits_current_layer, expanded_target)

                if losses_current_prompt_all_layers is None:
                    losses_current_prompt_all_layers = losses_current_layer
                else:
                    losses_current_prompt_all_layers += losses_current_layer
                del logits_current_layer

            # Sum over prompts as well.
            if losses_all_prompts_all_layers is None:
                losses_all_prompts_all_layers = losses_current_prompt_all_layers
            else:
                losses_all_prompts_all_layers += losses_current_prompt_all_layers
            gc.collect()
            torch.cuda.empty_cache()

        best_new_adv_suffix_id = losses_all_prompts_all_layers.argmin()
        best_new_adv_suffix = new_adv_suffix[best_new_adv_suffix_id]

        adv_suffix = best_new_adv_suffix

    # Step 4. Decide whether to activate another prompt. The probabilities
    # used here were measured on the suffix as it was before this step's update.
    if len(texts) < len(prompt_indices):

        if percentage_of_successful_prompts_all_layers(probs_all_prompts_all_layers, confidence_threshold=.7) >= .8:
            # If the attack is successful on 80% or more prompts, add next prompt
            add_next_prompt()
            last_added = i

        if len(texts) >= 5 and i - last_added == 10:
            # If there are already 5 or more prompts and one or more are being recalcitrant
            # and the suffix is not working on them, it might be helpful to add a new prompt
            add_next_prompt()
            last_added = i
        elif len(texts) >= 3 and i - last_added == 20:
            # Same idea with 3-4 active prompts, after a longer wait.
            add_next_prompt()
            last_added = i
        elif len(texts) == 2 and i - last_added == 30:
            # With exactly 2 active prompts, wait longest before adding one.
            add_next_prompt()
            last_added = i

end_time = time.time()

print(f"Total time: {end_time - start_time} seconds")

i: 0
losses: [5.65032196, 23.89462852, 10.51412106, 29.16087723, 18.99955177] labels: [1, 1, 1, 1, 1] probs: [[0.00351006, 0.99648994], [0.00000000, 1.00000000], [0.00002716, 0.99997284], [0.00000000, 1.00000000], [0.00000001, 0.99999999]]
! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !
-------------------------------------------------------



  return _nested.nested_tensor(


i: 1
losses: [5.35932493, 22.87190437, 8.65739441, 22.86856461, 15.62379646] labels: [1, 1, 1, 1, 1] probs: [[0.00469678, 0.99530322], [0.00000000, 1.00000000], [0.00017405, 0.99982595], [0.00000000, 1.00000000], [0.00000016, 0.99999984]]
! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !!"
-------------------------------------------------------

i: 2
losses: [5.44714355, 22.65389824, 7.75380230, 21.02862930, 13.94448948] labels: [1, 1, 1, 1, 1] probs: [[0.00430289, 0.99569711], [0.00000000, 1.00000000], [0.00043016, 0.99956984], [0.00000000, 1.00000000], [0.00000088, 0.99999912]]
! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !_"
-------------------------------------------------------

i: 3
losses: [5.13147974, 23.60064507, 6.24035358, 18.10927200, 12.14651680] labels: [1, 1, 1, 1, 1] probs: [[0.00590146, 0.99409854], [0.00000000, 1.00000000], [0.00194854, 0.99805146], [0.00000001, 0.99999999], [0.00000532, 0.99999468]]
! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !>::_"
------------------------------------------------