### Let's try to find a single adversarial suffix that works across many prompts when attacking the classifier on the last layer of the Phi-3 model

In [None]:
import torch
import torch.nn as nn
import numpy as np
import time
import pickle
from custom_model import CustomModel
from opt_utils import get_nonascii_toks, token_gradients, sample_control, get_filtered_cands, get_logits, load_model_and_tokenizer, get_prompt, get_primary_activation
from suffix_manager import SuffixManager
import random

# Print arrays/tensors in full, without scientific notation, on one very wide line.
np.set_printoptions(linewidth=10000, suppress=True)
torch.set_printoptions(
    threshold=float('inf'),
    linewidth=100000,
    sci_mode=False,
)

In [None]:
# Local directory with the Phi-3 weights/tokenizer, handed to load_model_and_tokenizer.
model_path = "loaded_models/phi3"

In [None]:
# Adversarial suffix the search starts from (repeated "!" characters).
adv_string_init = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"
model_id = "microsoft/Phi-3-mini-4k-instruct"

# Still using 2 A5000 24GB GPUs, so keep the number of attacked prompts small.
num_prompts = 5
# Prompt indices are drawn from range(num_candidate_prompts); presumably this is the
# size of the dataset behind get_prompt / get_primary_activation — TODO confirm.
num_candidate_prompts = 1000
# Set to an int to make the prompt selection reproducible. None keeps the original
# behaviour of sampling from the global `random` module state.
prompt_seed = None
prompt_rng = random if prompt_seed is None else random.Random(prompt_seed)
prompt_indices = prompt_rng.sample(range(num_candidate_prompts), num_prompts)

num_steps = 300   # optimisation steps of the suffix search
topk = 64         # top-k token substitutions considered per suffix position
batch_size = 64   # candidate suffixes sampled per step
allow_non_ascii = False  # you can set this to True to use unicode tokens

In [None]:
# Show which test-dataset prompts were drawn for this run.
print(f"{prompt_indices}")

In [None]:
# Release cached-but-unused GPU memory before the model is loaded.
# (empty_cache is itself a no-op when CUDA has not been initialised.)
if torch.cuda.is_initialized():
    torch.cuda.empty_cache()

In [None]:
# Get the text from the test dataset
# Begin with only the first sampled prompt; further prompts are appended by the
# optimisation loop as the suffix starts to succeed.
texts = []
texts.append(get_prompt(prompt_indices[0]))

In [None]:
# Load the base model and its tokenizer from model_path.
(model, tokenizer) = load_model_and_tokenizer(model_path)

In [None]:
# Sanity check: the device reported by the loaded model.
print(f"{model.device}")

In [None]:
# Suffix manager for the first prompt, seeded with the initial adversarial string.
suffix_manager = SuffixManager(
    tokenizer,
    texts[0],
    adv_string_init,
)

In [None]:
# Linear probe used as the attacked classifier; the "31" directory presumably
# denotes the layer index of the probed activations — TODO confirm.
# SECURITY: pickle.load can execute arbitrary code — only load probe files you trust.
probe_path = 'task_drift/trained_linear_probes_microsoft/phi3/31/model.pickle'
with open(probe_path, 'rb') as probe_file:  # closes the handle (the original leaked it)
    linear_model = pickle.load(probe_file)

custom_model = CustomModel(model, linear_model)
# Device of the base model's input-embedding matrix; all inputs are moved here.
device = custom_model.base_model.get_input_embeddings().weight.device

In [None]:
# Sanity check: device of the input embeddings.
print(f"{device}")

In [None]:
# Reference activation for the first prompt (presumably the suffix-free one — it is
# subtracted from the suffixed activation in check_misclassification). unsqueeze(0)
# adds a leading dim so later prompts' activations can be concatenated on dim 0.
primary_activations = get_primary_activation(prompt_indices[0]).unsqueeze(0).to(device)

In [None]:
# Label the attack drives the probe towards (class 0), as a float tensor because
# it is the target of BCEWithLogitsLoss.
target = torch.tensor([0], dtype=torch.float32, device=device)

In [None]:
def check_misclassification(adv_suffix, suffix_manager, primary_activation):
    """Classify the activation delta caused by inserting *adv_suffix* into a prompt.

    Runs the module-level ``model`` on the token ids that *suffix_manager*
    builds for *adv_suffix*. It takes the final token's hidden state from the
    last entry of ``hidden_states``, subtracts *primary_activation*, and feeds
    the difference to the module-level ``linear_model`` probe.

    Args:
        adv_suffix: Candidate adversarial suffix string.
        suffix_manager: SuffixManager for the prompt; ``get_input_ids`` builds
            the prompt's token ids with *adv_suffix* inserted.
        primary_activation: Reference activation subtracted from the suffixed
            one; presumably a 1-D tensor of hidden size — TODO confirm.

    Returns:
        ``(label, prob)`` — the probe's predicted class as ``int`` and its
        ``predict_proba`` row for this single sample.
    """
    device = custom_model.base_model.get_input_embeddings().weight.device
    input_ids = suffix_manager.get_input_ids(adv_suffix).unsqueeze(0).to(device)
    attention_mask = torch.ones_like(input_ids)  # already on input_ids' device

    # Evaluation only: without no_grad the forward pass records an autograd graph
    # through every layer, which wastes GPU memory for nothing.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Last hidden-state entry, batch item 0, final position: the activation with the suffix.
        poisoned_activation = outputs['hidden_states'][-1][0, -1].float()

        # The probe expects a 2-D (n_samples, n_features) array with one sample here.
        delta = (poisoned_activation - primary_activation).detach().cpu().float().numpy().reshape(1, -1)

    label = int(linear_model.predict(delta)[0])
    prob = linear_model.predict_proba(delta)[0]

    return label, prob

In [None]:
# Greedy coordinate-gradient search for ONE adversarial suffix shared by all prompts.
# The search starts with the first sampled prompt. Whenever column 0 of the most recently
# added prompt's predict_proba row reaches 0.8, the next prompt from `prompt_indices`
# is added, so the suffix is optimised over a growing set of prompts.
# NOTE(review): `not_allowed_tokens` is computed but never passed to sample_control,
# so `allow_non_ascii` currently has no effect on sampling — see Step 3.2.
not_allowed_tokens = None if allow_non_ascii else get_nonascii_toks(tokenizer)
adv_suffix = adv_string_init
# reduction='none' keeps one loss value per candidate suffix so we can argmin over them.
bce_loss = nn.BCEWithLogitsLoss(reduction='none')

start_time = time.time()

for i in range(num_steps + 1):

    coordinate_grad = None
    loss = None
    input_ids = None
    labels = []
    probs = []
    # (suffix_manager, input_ids) for every prompt at the current adv_suffix. They are
    # reused in Step 3.4 so each prompt's candidates are scored on THAT prompt's tokens
    # and adversarial slice, paired with its own primary activation. (Previously every
    # prompt was scored with the last prompt's input_ids.)
    prompt_inputs = []

    for text, primary_activation in zip(texts, primary_activations):

        # Step 1. Encode user prompt (behavior + adv suffix) as tokens and return token ids.
        suffix_manager = SuffixManager(tokenizer, text, adv_string_init)
        input_ids = suffix_manager.get_input_ids(adv_suffix)
        input_ids = input_ids.to(device)
        prompt_inputs.append((suffix_manager, input_ids))

        # Step 2. Compute Coordinate Gradient (gradients and losses are summed over prompts).
        coordinate_grad_, loss_, outputs, one_hot = token_gradients(custom_model, input_ids, suffix_manager.adv_string_slice, target, primary_activation)

        if coordinate_grad is None:
            coordinate_grad = coordinate_grad_
            loss = loss_
        else:
            coordinate_grad += coordinate_grad_
            loss += loss_

        label, prob = check_misclassification(adv_suffix, suffix_manager, primary_activation)
        labels.append(label)
        probs.append(prob)

    print(f"i: {i}  loss: {loss:.10f}   labels: {labels}   probs: {probs}")
    print(adv_suffix)
    print("-------------------------------------------------------\n")

    if i == num_steps:
        break

    # Step 3. Sample a batch of new tokens based on the coordinate gradient.
    # Notice that we only need the one that minimizes the loss.
    with torch.no_grad():

        # Step 3.1 Slice the input to locate the adversarial suffix. This uses the last
        # prompt's ids; all prompts share the same adv_suffix, so the suffix tokens are
        # presumably identical across prompts — TODO confirm for boundary tokenisation.
        adv_suffix_tokens = input_ids[suffix_manager.adv_string_slice].to(device)

        # Step 3.2 Randomly sample a batch of replacements.
        # Encoded suffixes
        new_adv_suffix_tokens = sample_control(
            adv_suffix_tokens,
            coordinate_grad,
            batch_size,
            topk=topk
        )

        # Step 3.3 This step ensures all adversarial candidates have the same number of tokens.
        # Decoded suffixes
        new_adv_suffix = get_filtered_cands(
            tokenizer,
            new_adv_suffix_tokens,
            filter_cand=True,
            curr_control=adv_suffix
        )

        # Step 3.4 Compute the summed loss over prompts for every candidate and take the argmin.
        losses = None

        for (prompt_manager, prompt_ids), primary_activation in zip(prompt_inputs, primary_activations):
            logits = get_logits(
                custom_model=custom_model,
                tokenizer=tokenizer,
                input_ids=prompt_ids,
                control_slice=prompt_manager.adv_string_slice,
                primary_activation=primary_activation,
                test_controls=new_adv_suffix,
                return_ids=False,
                batch_size=12 # decrease this number if you run into OOM.
            )
            # Cast a local copy instead of rebinding the global `target`, so Step 2 keeps
            # receiving the float32 target it was given on the first step.
            expanded_target = target.to(logits.dtype).expand_as(logits)

            losses_ = bce_loss(logits, expanded_target)
            losses = losses_ if losses is None else losses + losses_

        best_new_adv_suffix_id = losses.argmin()
        best_new_adv_suffix = new_adv_suffix[best_new_adv_suffix_id]

        adv_suffix = best_new_adv_suffix

    # Grow the prompt set once the newest prompt's probe probability (predict_proba
    # column 0, i.e. class 0 if the probe's classes are [0, 1]) reaches 0.8.
    if len(texts) < len(prompt_indices) and probs[-1][0] >= .8:
        next_index = prompt_indices[len(texts)]
        texts.append(get_prompt(next_index))
        new_activation = get_primary_activation(next_index).unsqueeze(0).to(device)
        primary_activations = torch.cat([primary_activations, new_activation], dim=0)


end_time = time.time()

print(f"Total time: {end_time - start_time} seconds")