In [1]:
import torch
import torch.nn as nn
import numpy as np
import time
import pickle
from phi.custom_model import CustomModel
from phi.opt_utils import get_nonascii_toks, token_gradients, sample_control, get_filtered_cands, get_logits, load_model_and_tokenizer, get_training_prompts
from phi.suffix_manager import SuffixManager
from utils.load_file_paths import load_file_paths
import random
import os
import sys
import gc
from constants import PROJECT_ROOT, LAYER_MAP

np.set_printoptions(suppress=True, linewidth=10000)
torch.set_printoptions(sci_mode=False, linewidth=100000, threshold=float('inf'))



In [2]:
model_name = 'phi3'

In [3]:
model_path = f'{PROJECT_ROOT}/loaded_models/{model_name}'

In [4]:
adv_string_init = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"

# Randomly select 100 from the first 100000 prompts
prompt_indices = random.sample(range(100000), 100)

num_steps = 2000
topk = 256
batch_size = 512
allow_non_ascii = False  # you can set this to True to use unicode tokens

In [5]:
print(prompt_indices)

[43350, 96, 52860, 86977, 60302, 12965, 47082, 72076, 16331, 30366, 51102, 14554, 3215, 32600, 55314, 78397, 22593, 73076, 37942, 40109, 44524, 73737, 42329, 90455, 58607, 93449, 16922, 86746, 73328, 61832, 78553, 83022, 43335, 86716, 13173, 279, 80890, 79162, 18886, 59925, 52315, 82825, 42885, 45087, 76587, 52923, 94791, 61494, 68079, 34360, 41862, 72181, 74775, 15706, 20060, 57200, 36893, 82547, 84278, 12738, 26015, 96150, 94351, 20075, 13032, 7626, 71869, 55325, 92428, 39888, 31762, 29393, 63669, 47938, 96832, 71199, 44287, 58632, 10021, 2169, 10967, 94936, 43501, 4088, 63139, 20607, 98324, 90377, 26617, 49850, 99588, 48259, 51193, 68091, 44671, 82861, 59530, 50396, 94252, 68016]


In [6]:
torch.cuda.empty_cache()

In [7]:
# Get the text from the train dataset

training_prompts = get_training_prompts()

texts = [training_prompts[prompt_indices[0]]]

In [8]:
model, tokenizer = load_model_and_tokenizer(model_path)

print(model.dtype)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

torch.bfloat16


In [9]:
linear_models = {}

layers = []

for (dirpath, dir_names, filenames) in os.walk(f'{PROJECT_ROOT}/trained_linear_probes_microsoft/{model_name}'):
    layers = [int(dir_name) for dir_name in dir_names]
    break

layers.sort()

for i in layers:
    linear_models[i] = pickle.load(open(f'{PROJECT_ROOT}/trained_linear_probes_microsoft/{model_name}/{i}/model.pickle', 'rb'))

custom_model = CustomModel(model, linear_models)
device = custom_model.base_model.get_input_embeddings().weight.device

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [10]:
def get_primary_activation(index, model, layer, subset):

    index_in_file = index - int(index / 1000) * 1000

    if subset == 'test':
        filepaths = load_file_paths(f'{PROJECT_ROOT}/data_files/test_poisoned_files_{model}.txt')
    else:
        filepaths = load_file_paths(f'{PROJECT_ROOT}/data_files/train_files_{model}.txt')

    activation_file_index_in_list = 0

    for idx, filepath in enumerate(filepaths):
        if filepath.count(f'_{int(index / 1000) * 1000}_{(int(index / 1000) + 1) * 1000}_') == 1:
            activation_file_index_in_list = idx
            break

    activations = torch.load(f'/home/40456997@eeecs.qub.ac.uk/Reduced Activation/{model}/{subset}/{filepaths[activation_file_index_in_list]}')

    return activations[0][index_in_file][LAYER_MAP[layer]]


def get_primary_activations(prompt_index):

    primary_activations = {
        layer: get_primary_activation(index=prompt_index, model=model_name, layer=layer, subset='training').to(device) for layer in layers
    }

    return primary_activations

In [11]:
# A list of dict
primary_activations_all_prompts = [get_primary_activations(prompt_indices[0])]

In [12]:
target = torch.tensor([0], dtype=torch.float32)
target = target.to(device)

In [13]:
def check_misclassification(adv_suffix, suffix_manager, primary_activations):
    device = custom_model.base_model.get_input_embeddings().weight.device
    input_ids = suffix_manager.get_input_ids(adv_suffix).unsqueeze(0)
    input_ids = input_ids.to(device)

    attention_mask = torch.ones_like(input_ids).to(device)

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

    labels = []
    probs = []

    for num_layer, linear_model in linear_models.items():
        poisoned_activation = outputs['hidden_states'][num_layer + 1][0, -1].float()
        delta = (poisoned_activation - primary_activations[num_layer]).detach().cpu().float().numpy().reshape(1, -1)

        label = linear_model.predict(delta)
        prob = linear_model.predict_proba(delta)

        labels.append(label[0].tolist())
        probs.append(prob[0].tolist())

    return labels, probs



In [14]:
def format_probs(probs):
    formatted_probs = []
    for prob_pair in probs:
        formatted_pair = [f"{p:.8f}" for p in prob_pair]
        formatted_probs.append(f"[{formatted_pair[0]}, {formatted_pair[1]}]")
    probs_str = "[" + ", ".join(formatted_probs) + "]"

    return probs_str

def format_losses(losses):
    losses_str = "[" + ", ".join([f'{loss:.8f}' for loss in losses]) + "]"
    return losses_str

In [15]:
def check_probabilities(probs_all_prompts_all_layers):
    mn = 100.0

    for probs_current_prompt_all_layers in probs_all_prompts_all_layers:
        for probs_current_prompt_current_layer in probs_current_prompt_all_layers:
            mn = min(mn, probs_current_prompt_current_layer[0])
    return mn


def percentage_of_successful_prompts_all_layers(probs_all_prompts_all_layers, confidence_threshold=.8):

    cnt = 0

    for probs_current_prompt_all_layers in probs_all_prompts_all_layers:
        mn = 100.0
        for probs_current_prompt_current_layer in probs_current_prompt_all_layers:
            mn = min(mn, probs_current_prompt_current_layer[0])

        if mn >= confidence_threshold:
            cnt += 1

    return cnt / len(probs_all_prompts_all_layers)


def add_next_prompt():
    # Add next prompt
    texts.append(training_prompts[prompt_indices[len(texts)]])

    primary_activations = get_primary_activations(prompt_indices[len(primary_activations_all_prompts)])
    # Add primary activations of this prompt
    primary_activations_all_prompts.append(primary_activations)


In [None]:

suffix_manager = SuffixManager(tokenizer, texts[0], adv_string_init)
not_allowed_tokens = None if allow_non_ascii else get_nonascii_toks(tokenizer)
adv_suffix = adv_string_init

start_time = time.time()

last_added = 0

for i in range(num_steps + 1):

    gc.collect()

    coordinate_grad_all_prompts = None
    losses_all_prompts_all_layers = []

    input_ids = None

    # A list of lists
    labels_all_prompts_all_layers = []

    # A list of lists of lists
    probs_all_prompts_all_layers = []

    original_stdout = sys.stdout
    sys.stdout = open(os.devnull, 'w')

    for text, primary_activations_current_prompt in zip(texts, primary_activations_all_prompts):

        # Step 1. Encode user prompt (behavior + adv suffix) as tokens and return token ids.
        suffix_manager = SuffixManager(tokenizer, text, adv_string_init)
        input_ids = suffix_manager.get_input_ids(adv_suffix)
        input_ids = input_ids.to(device)

        # Step 2. Compute Coordinate Gradient
        coordinate_grad_current_prompt, losses_current_prompt_all_layers = token_gradients(custom_model, input_ids, suffix_manager.adv_string_slice,
                                                                    target, primary_activations=primary_activations_current_prompt)

        losses_all_prompts_all_layers.append(losses_current_prompt_all_layers)

        if coordinate_grad_all_prompts is None:
            coordinate_grad_all_prompts = coordinate_grad_current_prompt
        else:
            coordinate_grad_all_prompts += coordinate_grad_current_prompt

        labels_current_prompt_all_layers, probs_current_prompt_all_layers = check_misclassification(adv_suffix, suffix_manager, primary_activations_current_prompt)
        labels_all_prompts_all_layers.append(labels_current_prompt_all_layers)
        probs_all_prompts_all_layers.append(probs_current_prompt_all_layers)

        gc.collect()


    sys.stdout = original_stdout

    print(f"i: {i}")

    for idx in range(len(labels_all_prompts_all_layers)):
        print(f"losses: {format_losses(losses_all_prompts_all_layers[idx])} labels: {labels_all_prompts_all_layers[idx]} probs: {format_probs(probs_all_prompts_all_layers[idx])}")
    print(adv_suffix)
    print("-------------------------------------------------------\n")

    if i == num_steps:
        break

    # Step 3. Sample a batch of new tokens based on the coordinate gradient.
    # Notice that we only need the one that minimizes the loss.
    with torch.no_grad():

        # Step 3.1 Slice the input to locate the adversarial suffix.
        adv_suffix_tokens = input_ids[suffix_manager.adv_string_slice].to(device)

        # Step 3.2 Randomly sample a batch of replacements.
        # Encoded suffixes
        new_adv_suffix_tokens = sample_control(
            adv_suffix_tokens,
            coordinate_grad_all_prompts,
            batch_size,
            not_allowed_tokens=not_allowed_tokens,
            topk=topk
        )

        # Step 3.3 This step ensures all adversarial candidates have the same number of tokens.
        # Decoded suffixes
        new_adv_suffix = get_filtered_cands(
            tokenizer,
            new_adv_suffix_tokens,
            filter_cand=True,
            curr_control=adv_suffix
        )

        losses_all_prompts_all_layers = None

        for text, primary_activations_current_prompt in zip(texts, primary_activations_all_prompts):

            suffix_manager = SuffixManager(tokenizer, text, adv_string_init)
            input_ids = suffix_manager.get_input_ids(adv_suffix)
            input_ids = input_ids.to(device)

            # Step 3.4 Compute loss on these candidates and take the argmin.
            logits_per_classifier = get_logits(
                custom_model=custom_model,
                tokenizer=tokenizer,
                input_ids=input_ids,
                control_slice=suffix_manager.adv_string_slice,
                primary_activations=primary_activations_current_prompt,
                test_controls=new_adv_suffix,
                batch_size=8 # decrease this number if you run into OOM.
            )
            last_classifier_logits = next(reversed(logits_per_classifier.values()))
            target = target.to(last_classifier_logits.dtype)
            expanded_target = target.expand_as(last_classifier_logits)

            losses_current_prompt_all_layers = None

            for num_layer, logits_current_layer in logits_per_classifier.items():
                losses_current_layer = nn.BCEWithLogitsLoss(reduction='none')(logits_current_layer, expanded_target)

                if losses_current_prompt_all_layers is None:
                    losses_current_prompt_all_layers = losses_current_layer
                else:
                    losses_current_prompt_all_layers += losses_current_layer
                del logits_current_layer

            if losses_all_prompts_all_layers is None:
                losses_all_prompts_all_layers = losses_current_prompt_all_layers
            else:
                losses_all_prompts_all_layers += losses_current_prompt_all_layers
            gc.collect()
            torch.cuda.empty_cache()

        best_new_adv_suffix_id = losses_all_prompts_all_layers.argmin()
        best_new_adv_suffix = new_adv_suffix[best_new_adv_suffix_id]

        adv_suffix = best_new_adv_suffix

    if len(texts) < len(prompt_indices):

        if percentage_of_successful_prompts_all_layers(probs_all_prompts_all_layers, confidence_threshold=.7) >= .8:
            # If the attack is successful on 80% or more prompts, add next prompt
            add_next_prompt()
            last_added = i

        if len(texts) >= 5 and i - last_added == 10:
            # If there are already 5 or more prompts and the latest one or any previous one is being recalcitrant
            # and the suffix is not working on it, it might be helpful to add a new prompt
            add_next_prompt()
            last_added = i
        elif len(texts) >= 3 and i - last_added == 20:
            add_next_prompt()
            last_added = i
        elif len(texts) == 2 and i - last_added == 30:
            add_next_prompt()
            last_added = i

end_time = time.time()

print(f"Total time: {end_time - start_time} seconds")

i: 0
losses: [7.72785997, 27.69341850, 22.53462601, 41.18885040, 44.48118973] labels: [1, 1, 1, 1, 1] probs: [[0.00044070, 0.99955930], [0.00000000, 1.00000000], [0.00000000, 1.00000000], [0.00000000, 1.00000000], [0.00000000, 1.00000000]]
! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !
-------------------------------------------------------



  return _nested.nested_tensor(


i: 1
losses: [7.38889980, 24.34912872, 21.50238037, 38.87343979, 40.11109924] labels: [1, 1, 1, 1, 1] probs: [[0.00061709, 0.99938291], [0.00000000, 1.00000000], [0.00000000, 1.00000000], [0.00000000, 1.00000000], [0.00000000, 1.00000000]]
! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !')
-------------------------------------------------------

i: 2
losses: [7.15892744, 24.97650719, 22.22542381, 32.53679657, 35.55488586] labels: [1, 1, 1, 1, 1] probs: [[0.00077632, 0.99922368], [0.00000000, 1.00000000], [0.00000000, 1.00000000], [0.00000000, 1.00000000], [0.00000000, 1.00000000]]
!est ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !')
-------------------------------------------------------

i: 3
losses: [6.78120327, 24.22472191, 19.64665222, 29.21773148, 32.54939270] labels: [1, 1, 1, 1, 1] probs: [[0.00113454, 0.99886546], [0.00000000, 1.00000000], [0.00000000, 1.00000000], [0.00000000, 1.00000000], [0.00000000, 1.00000000]]
!est ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !)`.')
-------------------------------------------