In [34]:
from transformers import BertModel, BertTokenizer
import numpy as np
import torch
from torch.nn import functional as F


In [35]:
from transformers import BertTokenizer, BertForSequenceClassification

class Dummy:
    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForSequenceClassification.from_pretrained(model_name)
        self.vocab = self.tokenizer.get_vocab()  # This provides the vocabulary

    def __call__(self, x):
        # Assuming x is a list of strings for simplicity
        inputs = self.tokenizer(x, return_tensors='pt', padding=True, truncation=True)
        outputs = self.model(**inputs)
        return outputs.logits

In [39]:
# Simplex Projection
def simplex_projection(s):
    s_sorted = np.sort(s)
    s_cumsum = np.cumsum(s_sorted)
    rho = np.argmax(s_sorted > (s_cumsum - 1) / np.arange(1, len(s) + 1))
    theta = (s_cumsum[rho] - 1) / (rho + 1)
    return np.maximum(s - theta, 0)

# Entropy Projection
def entropy_projection(s, target_entropy=2):
    mask = s > 0
    center = mask / np.sum(mask)
    R = np.sqrt(np.maximum(0, 1 - target_entropy - 1 / np.sum(mask)))
    if R >= np.linalg.norm(s - center):
        return s
    else:
        return simplex_projection(R / np.linalg.norm(s - center) * (s - center) + center)


In [37]:
def relaxed_one_hot(prompt, tokenizer):
    # Tokenize the string prompt
    token_ids = tokenizer.encode(prompt, add_special_tokens=False)
    vocab_size = tokenizer.vocab_size

    # Create one-hot encoding
    relaxed = np.zeros((len(token_ids), vocab_size))
    for i, token_id in enumerate(token_ids):
        relaxed[i, token_id] = 1

    # Add a small noise
    return relaxed + np.random.uniform(0, 1, relaxed.shape) * 0.1

In [44]:
# Adjusted PGD Attack with debugging prints
def pgd_attack(model, original_prompt, loss_fn, learning_rate, epochs):
    X_t = relaxed_one_hot(original_prompt, model.tokenizer)  # Pass the tokenizer

    best_loss = float('inf')
    best_prompt = None

    X_t_tensor = torch.tensor(X_t, requires_grad=True).float()
    X_t_t2= X_t_tensor.detach()
    optimizer = torch.optim.Adam([X_t_t2], lr=learning_rate)

    for t in range(epochs):
        optimizer.zero_grad()

        decoded_prompts = [model.tokenizer.decode([torch.argmax(X_t_tensor[i]).item()]) for i in range(X_t_tensor.shape[0])]
        logits = model(decoded_prompts)
        target = torch.tensor([1] * logits.shape[0], dtype=torch.long)  # Example target tensor with the same batch size
        loss = loss_fn(logits, target)

        loss.backward()
        optimizer.step()

        with torch.no_grad():
            X_t = X_t_tensor.detach().numpy()  # Detach the tensor to get a new numpy array for the next iteration
            for i in range(X_t.shape[0]):
                X_t[i] = simplex_projection(X_t[i])
                X_t[i] = entropy_projection(X_t[i])

        # Update X_t_tensor with the projected values
        X_t_tensor = torch.tensor(X_t, requires_grad=True).float()

        decoded_prompts = [model.tokenizer.decode([torch.argmax(X_t_tensor[i]).item()]) for i in range(X_t_tensor.shape[0])]
        logits = model(decoded_prompts)
        current_loss = loss_fn(logits, target)
        if current_loss < best_loss:
            best_loss = current_loss
            best_prompt = X_t

    return best_prompt

# Example Usage with debugging prints in projection functions
model = Dummy()
loss_fn = lambda logits, target: F.cross_entropy(logits, target)  # Example loss function

original_prompt = "example prompt for PGD attack"
learning_rate = 0.001  # Adjust learning rate
epochs = 10

# Perform PGD attack
adversarial_prompt = pgd_attack(model, original_prompt, loss_fn, learning_rate, epochs)
print("Adversarial Prompt:", adversarial_prompt)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Adversarial Prompt: [[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]]
