<center>
    <h1>Towards Interpretable Prompts</h1>
    <p>An Experiment Conducted by Oam Patel, Jason Wang, Nikhil Nayak, and Suraj Srinivas</p>
</center>

## Settings

In [4]:
# Path to the folder containing the saved models
BASE_DIR = "models"

# Whether to train new models or load models
TRAIN_MODELS = False

# Whether this is being run on Google Colab or not
COLAB = True

# The device to run the models on ("cpu" or "cuda")
device = "cuda"

# Prompt length
PROMPT_LENGTH = 5

# Model Name
MODEL_NAME = "google/flan-t5-small"

## Imports

In [2]:
!pip install transformers
!pip install -U datasets
!pip install sentencepiece



In [5]:
# Import Requisite Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqModelOutput, BaseModelOutput, Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5PreTrainedModel, T5Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.optimization import Adafactor
from torch.utils.data import DataLoader
from datasets import load_dataset
from tqdm.notebook import tqdm
import numpy as np
import random
from typing import Optional, Tuple, Union
import warnings
import copy
import os
import gc

# Set Random Seed
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [6]:
if COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [7]:
def mem_stats():
    '''
    Memory statistics for memory management
    '''
    t = torch.cuda.get_device_properties(0).total_memory / 1024**3
    r = torch.cuda.memory_reserved(0) / 1024**3
    a = torch.cuda.memory_allocated(0) / 1024**3
    print(f"Total Memory: {t:.2f} GB\n"
          f"Reserved Memory: {r:.2f} GB ({(100*(r/t)):.2f}%)\n"
          f"Remaining Memory: {t-r:.2f} GB ({(100*(t-r)/t):.2f}%)\n"
          f"---------------------------------\n"
          f"Allocated Memory: {a:.2f} GB ({(100*(a/t)):.2f}%)\n"
          f"Percent of Reserved Allocated: {(100*(a+1e-9)/(r+1e-9)):.2f}%\n")

In [8]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
mem_stats()

Total Memory: 14.74 GB
Reserved Memory: 0.00 GB (0.00%)
Remaining Memory: 14.74 GB (100.00%)
---------------------------------
Allocated Memory: 0.00 GB (0.00%)
Percent of Reserved Allocated: 100.00%



## Specify Config and Tokenizer
This section loads the configuration and tokenizer for the pre-trained model specified by MODEL_NAME. The configuration contains important parameters about the model architecture, and the tokenizer is used to convert text into a format the model can understand.

In [9]:
config = AutoConfig.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

## Loading and Pre-processing BoolQ Dataset

In [10]:
#Load train and validation datasets
train_dataset = load_dataset("google/boolq")["train"].with_format("torch")
val_dataset = load_dataset("google/boolq")["validation"].with_format("torch")

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/3.69M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9427 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3270 [00:00<?, ? examples/s]

In [11]:
# Select a subset of the datasets
subset_size = 500
train_dataset = train_dataset.shuffle(seed=seed).select(range(subset_size))
val_dataset = val_dataset.shuffle(seed=seed).select(range(subset_size))

In [12]:
# Prune examples > 512-PROMPT_LENGTH tokens
train_dataset = train_dataset.filter(lambda x: len(tokenizer(x["passage"]+" "+x["question"]).input_ids) + PROMPT_LENGTH < config.n_positions)
val_dataset = val_dataset.filter(lambda x: len(tokenizer(x["passage"]+" "+x["question"]).input_ids) + PROMPT_LENGTH < config.n_positions)

# Balance train dataset (there are more yesses than nos)
true_ids = [i for i in range(len(train_dataset)) if train_dataset[i]["answer"].item()]
false_ids = [i for i in range(len(train_dataset)) if not train_dataset[i]["answer"].item()]
true_ids = true_ids[:len(false_ids)]
train_dataset = train_dataset.select(true_ids+false_ids)

Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (524 > 512). Running this sequence through the model will result in indexing errors


Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

## Tokenization

This section tokenizes the datasets, converting the data (questions/answers) into numerical representations that the model can process.

In [13]:
def replace_labels(tokenized_output):
    # Replace padding token id's of the labels by -100 so it's ignored by the loss (as advised by https://huggingface.co/docs/transformers/v4.26.1/en/model_doc/t5#training)
    tokenized_output["input_ids"][tokenized_output["input_ids"]==tokenizer.pad_token_id] = -100
    tokenized_output["input_ids"] = tokenized_output["input_ids"].squeeze()
    return tokenized_output

# Tokenize datasets
train_dataset_tokenized = train_dataset.map(lambda x: {"x": tokenizer(x["passage"]+" "+x["question"],padding="max_length",max_length=config.n_positions-PROMPT_LENGTH)})
train_dataset_tokenized = train_dataset_tokenized.map(lambda x: {"y": replace_labels(tokenizer("yes" if x["answer"] else "no",return_tensors="pt"))})

val_dataset_tokenized = val_dataset.map(lambda x: {"x": tokenizer(x["passage"]+" "+x["question"],padding="max_length",max_length=config.n_positions-PROMPT_LENGTH)})
val_dataset_tokenized = val_dataset_tokenized.map(lambda x: {"y": replace_labels(tokenizer("yes" if x["answer"] else "no",return_tensors="pt"))})

yes_token = tokenizer("yes")["input_ids"][0]
no_token = tokenizer("no")["input_ids"][0]

Map:   0%|          | 0/352 [00:00<?, ? examples/s]

Map:   0%|          | 0/352 [00:00<?, ? examples/s]

Map:   0%|          | 0/499 [00:00<?, ? examples/s]

Map:   0%|          | 0/499 [00:00<?, ? examples/s]

In [14]:
# Create PyTorch data loaders for training
train_dl = DataLoader(train_dataset_tokenized, batch_size=8, shuffle=True)
val_dl = DataLoader(val_dataset_tokenized, batch_size=8, shuffle=True)

## Baseline and Fine-Tuning

In [15]:
# Load the pre-trained Language Model
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model = model.to(device)
# model = model.half()
optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [17]:
# Validation
model.eval()
taskLoss = 0.0
nCorrect = 0
nTotal = 0
with torch.no_grad():
    for i, data in enumerate(tqdm(val_dl)):
        output = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device))
        logits = output.logits
        loss = output.loss
        taskLoss += loss.item() * len(data["x"]["input_ids"])
        nTotal += len(data["x"]["input_ids"])
        yes_logits = logits[:,0,yes_token]
        no_logits = logits[:,0,no_token]
        prediction = yes_logits>no_logits
        truth = (data["y"]["input_ids"][:,0]==yes_token).to(device)
        nCorrect += sum(truth==prediction)
print(f"Before Training Val Loss: {taskLoss/nTotal} Val Accuracy: {nCorrect/nTotal}")

#Training loop
for epoch in range(1,1+4):
    #Train
    taskLoss = 0.0
    nCorrect = 0
    nTotal = 0
    model.train()
    for i, data in enumerate(tqdm(train_dl)):
        optimizer.zero_grad()
        output = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device))
        logits = output.logits
        loss = output.loss
        loss.backward()
        optimizer.step()

        taskLoss += loss.item() * len(data["x"]["input_ids"])
        yes_logits = logits[:,0,yes_token]
        no_logits = logits[:,0,no_token]
        prediction = yes_logits>no_logits
        truth = (data["y"]["input_ids"][:,0]==yes_token).to(device)
        nCorrect += sum(truth==prediction)
        nTotal += len(data["x"]["input_ids"])
    print(f"Epoch {epoch} Train Loss: {taskLoss/nTotal} Train Accuracy: {nCorrect/nTotal}")

    # Validation
    model.eval()
    taskLoss = 0.0
    nCorrect = 0
    nTotal = 0
    with torch.no_grad():
        for i, data in enumerate(tqdm(val_dl)):
            output = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device))
            logits = output.logits
            loss = output.loss
            taskLoss += loss.item() * len(data["x"]["input_ids"])
            nTotal += len(data["x"]["input_ids"])
            yes_logits = logits[:,0,yes_token]
            no_logits = logits[:,0,no_token]
            prediction = yes_logits>no_logits
            truth = (data["y"]["input_ids"][:,0]==yes_token).to(device)
            nCorrect += sum(truth==prediction)
    print(f"Epoch {epoch} Val Loss: {taskLoss/nTotal} Val Accuracy: {nCorrect/nTotal}")

  0%|          | 0/63 [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Before Training Val Loss: 3.1742697022482007 Val Accuracy: 0.5470941662788391


  0%|          | 0/44 [00:00<?, ?it/s]

Epoch 1 Train Loss: 0.6266811632297256 Train Accuracy: 0.5142045617103577


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1 Val Loss: 0.3478514526435034 Val Accuracy: 0.5791583061218262


  0%|          | 0/44 [00:00<?, ?it/s]

Epoch 2 Train Loss: 0.38111833415248175 Train Accuracy: 0.5426136255264282


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 2 Val Loss: 0.35174350413626326 Val Accuracy: 0.5010020136833191


  0%|          | 0/44 [00:00<?, ?it/s]

Epoch 3 Train Loss: 0.35222667354074394 Train Accuracy: 0.6136363744735718


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 3 Val Loss: 0.33699584532835203 Val Accuracy: 0.5891783833503723


  0%|          | 0/44 [00:00<?, ?it/s]

Epoch 4 Train Loss: 0.2982327557084235 Train Accuracy: 0.7045454978942871


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 4 Val Loss: 0.37429537813744707 Val Accuracy: 0.5911823511123657


Results of a training:

The model was trained for 4 epochs

| Epoch | Train Loss | Train Accuracy | Val Loss | Val Accuracy |
|-------|------------|----------------|----------|--------------|
| Before Training | - | - | 3.174 | 0.547 |
| 1     | 0.627      | 0.514          | 0.348    | 0.579        |
| 2     | 0.381      | 0.543          | 0.352    | 0.501        |
| 3     | 0.352      | 0.614          | 0.337    | 0.589        |
| 4     | 0.298      | 0.705          | 0.374    | 0.591        |

the training loss decreased over the epochs, and the training accuracy increased. The validation loss and accuracy are not as stable, this could be caused by overfitting because the training accuracy is significantly higher than the validation accuracy.

## Prompt Tuning with GCG

In [19]:
class PromptedModel(nn.Module):
    def __init__(self, model, prompt_length, vocab_size):
        super().__init__()
        self.model = model
        self.prompt_length = prompt_length
        self.vocab_size = vocab_size
        self.prompt = nn.Parameter(F.one_hot(torch.randint(0,self.vocab_size,(self.prompt_length,)),num_classes=self.vocab_size).float())
    def forward(self,input_ids=None,attention_mask=None,labels=None):
        input_one_hot = torch.cat((F.one_hot(input_ids,num_classes=self.vocab_size).float(),self.prompt.repeat(input_ids.shape[0],1,1)),dim=1)
        if attention_mask is not None:
            attention_mask = torch.cat((attention_mask,torch.ones(attention_mask.shape[0],self.prompt_length)),dim=1)
        return self.model(inputs_embeds=input_one_hot@self.model.shared.weight,attention_mask=attention_mask,labels=labels)

In [20]:
# set up
try:
    del model
except:
    pass
torch.cuda.empty_cache()
torch.cuda.synchronize()
train_dl = DataLoader(train_dataset_tokenized, batch_size=8, shuffle=True)
val_dl = DataLoader(val_dataset_tokenized, batch_size=8, shuffle=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model = PromptedModel(model,PROMPT_LENGTH,config.vocab_size)
model = model.to(device)

In [21]:
k = 5
b = 10
# Validation
model.eval()
taskLoss = 0.0
nCorrect = 0
nTotal = 0
with torch.no_grad():
    for i, data in enumerate(tqdm(val_dl)):
        output = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device))
        logits = output.logits
        loss = output.loss
        taskLoss += loss.item() * len(data["x"]["input_ids"])
        nTotal += len(data["x"]["input_ids"])
        yes_logits = logits[:,0,yes_token]
        no_logits = logits[:,0,no_token]
        prediction = yes_logits>no_logits
        truth = (data["y"]["input_ids"][:,0]==yes_token).to(device)
        nCorrect += sum(truth==prediction)
print(f"Before Training Val Loss: {taskLoss/nTotal} Val Accuracy: {nCorrect/nTotal}")

#Training loop
for epoch in range(1,1+4):
    #Train
    taskLoss = 0.0
    nCorrect = 0
    nTotal = 0
    model.train()
    for i, data in enumerate(tqdm(train_dl)):
        model.zero_grad()
        output = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device))
        logits = output.logits
        loss = output.loss
        loss.backward()

        # Update Prompt with GCG
        top_k_candidates = model.prompt.grad.topk(k,dim=1,largest=False).indices.squeeze()
        best_loss = float("inf")
        best_candidate = None
        old_prompt = model.prompt.data.clone().detach()
        for perturbation in range(b):
            new_prompt = old_prompt.clone()
            i = torch.randint(0,model.prompt_length,(1,)).item()
            j = torch.randint(0,k,(1,)).item()
            new_prompt[i] = F.one_hot(top_k_candidates[i][j],num_classes=model.vocab_size).float()
            model.prompt.data = new_prompt.to(device)
            perturbed_loss = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device)).loss
            if perturbed_loss.item()<best_loss:
                best_loss = perturbed_loss.item()
                best_candidate = new_prompt
        model.prompt.data = best_candidate.to(device)

        if i%50==0:
            print(f"{best_loss:.5g}",tokenizer.decode(torch.argmax(model.prompt.data,dim=1)))

        taskLoss += loss.item() * len(data["x"]["input_ids"])
        yes_logits = logits[:,0,yes_token]
        no_logits = logits[:,0,no_token]
        prediction = yes_logits>no_logits
        truth = (data["y"]["input_ids"][:,0]==yes_token).to(device)
        nCorrect += sum(truth==prediction)
        nTotal += len(data["x"]["input_ids"])
    print(f"Epoch {epoch} Train Loss: {taskLoss/nTotal} Train Accuracy: {nCorrect/nTotal}")

    # Validation
    model.eval()
    taskLoss = 0.0
    nCorrect = 0
    nTotal = 0
    with torch.no_grad():
        for i, data in enumerate(tqdm(val_dl)):
            output = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device))
            logits = output.logits
            loss = output.loss
            taskLoss += loss.item() * len(data["x"]["input_ids"])
            nTotal += len(data["x"]["input_ids"])
            yes_logits = logits[:,0,yes_token]
            no_logits = logits[:,0,no_token]
            prediction = yes_logits>no_logits
            truth = (data["y"]["input_ids"][:,0]==yes_token).to(device)
            nCorrect += sum(truth==prediction)
    print(f"Epoch {epoch} Val Loss: {taskLoss/nTotal} Val Accuracy: {nCorrect/nTotal}")

  0%|          | 0/63 [00:00<?, ?it/s]

Before Training Val Loss: 2.5851318661817806 Val Accuracy: 0.5691382884979248


  0%|          | 0/44 [00:00<?, ?it/s]

3.3607 Ard50 Century Spiele Rand
3.4636 mono50 sediment Spiele Rand
2.5252 PlayStationumps ChâteauASAtran
3.3924 ZIPAmusine720 pollutants
Epoch 1 Train Loss: 3.7084557468240913 Train Accuracy: 0.5


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1 Val Loss: 1.9739077965577763 Val Accuracy: 0.48496994376182556


  0%|          | 0/44 [00:00<?, ?it/s]

2.876 findest Kusinepoarte quoted
3.338 SAN Kusine0% occupancy
3.6025 SceneVOzig conveyor exhibitors
3.1282 TrVOzig conveyor exhibitors
3.2594 TrVOzig conveyor rock
3.4293 TrVOintroductory Bearing dataset
3.2924 TrVOintroductory gospel limousine
3.2974 TrVO matin gospel limousine
3.2121 Travisconverge matin gospel limousine
2.937 Travistran matin gospel limousine
2.9018 Hwytran matin gospel limousine
3.07 Travistran matin gospel limousine
Epoch 2 Train Loss: 3.566347062587738 Train Accuracy: 0.4943181872367859


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 2 Val Loss: 1.9134314786456152 Val Accuracy: 0.5450901985168457


  0%|          | 0/44 [00:00<?, ?it/s]

3.1509 Vincent ceuxLRUniversitéANC
3.1013 Statistical ceuxLR peutANC
3.4727 Uploaded ceux Feetran responsive
2.7413 flashlight ceux SEtran responsive
2.9493 flashlight° SEtran responsive
3.097 978° SEtran responsive
3.0685 LR160 LEDtran7%
3.1826 Randcali LEDtran pricing
2.8723 Randcali LED automated pricing
Epoch 3 Train Loss: 3.5835046659816396 Train Accuracy: 0.5028409361839294


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 3 Val Loss: 1.7407717807498388 Val Accuracy: 0.46893787384033203


  0%|          | 0/44 [00:00<?, ?it/s]

2.9821 limousinecali diagnosis Wach978
2.8272 Trans visitors Subaru Randscalable
3.1527 einer visitors Subaru Randservice
2.5529 outil deloc regression Randservice
3.1835 Ellen deloc regressionusine affordability
3.1273 Maschineintroductory regression Categories866
3.4413 Kondik Prtraglingual
2.8995 owohldik Prtraglingual
3.0729 McCdik Prtraglingual
Epoch 4 Train Loss: 3.640860693021254 Train Accuracy: 0.5056818127632141


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 4 Val Loss: 1.8409310951978266 Val Accuracy: 0.4488978087902069


Prompt Tuning with GCG

| Epoch | Train Loss | Train Accuracy | Val Loss | Val Accuracy |
|-------|------------|----------------|----------|--------------|
| Before Training | - | - | 2.585 | 0.569 |
| 1     | 3.708      | 0.500          | 1.974    | 0.485        |
| 2     | 3.566      | 0.494          | 1.913    | 0.545        |
| 3     | 3.584      | 0.503          | 1.741    | 0.469        |
| 4     | 3.641      | 0.506          | 1.841    | 0.449        |

The training loss did not decrease over the ephocs and remained very high and the training accuract remained consistently around 0.50, meaning that the model is not leaning from the training data.

The validation accuracy did not improve and remained somehow constant, but the results obtained are not improved compared to the accuracy from before the training.

## Prompt Tuning with GCG + Perplexity Regularization

In [22]:
class InputGradientModel(nn.Module):
    def __init__(self, model, prompt_length, vocab_size):
        super().__init__()
        self.model = model
        self.prompt_length = prompt_length
        self.vocab_size = vocab_size
        self.prompt = nn.Parameter(F.one_hot(torch.randint(0,self.vocab_size,(self.prompt_length,)),num_classes=self.vocab_size).float())
    def forward(self,input_ids=None,attention_mask=None,labels=None):
        self.prompt.data = F.one_hot(input_ids,num_classes=self.vocab_size).float()
        input_one_hot = self.prompt
        return self.model(inputs_embeds=input_one_hot@self.model.shared.weight,attention_mask=attention_mask,labels=labels)

In [23]:
try:
    del model
except:
    pass
torch.cuda.empty_cache()
torch.cuda.synchronize()
train_dl = DataLoader(train_dataset_tokenized, batch_size=8, shuffle=True)
val_dl = DataLoader(val_dataset_tokenized, batch_size=8, shuffle=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model = PromptedModel(model,PROMPT_LENGTH,config.vocab_size)
model = model.to(device)
perplexity_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
perplexity_model = InputGradientModel(perplexity_model,PROMPT_LENGTH,config.vocab_size)
perplexity_model = perplexity_model.to(device)

In [24]:
k = 5
b = 10
# Validation
model.eval()
taskLoss = 0.0
nCorrect = 0
nTotal = 0
with torch.no_grad():
    for i, data in enumerate(tqdm(val_dl)):
        output = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device))
        logits = output.logits
        loss = output.loss
        taskLoss += loss.item() * len(data["x"]["input_ids"])
        nTotal += len(data["x"]["input_ids"])
        yes_logits = logits[:,0,yes_token]
        no_logits = logits[:,0,no_token]
        prediction = yes_logits>no_logits
        truth = (data["y"]["input_ids"][:,0]==yes_token).to(device)
        nCorrect += sum(truth==prediction)
print(f"Before Training Val Loss: {taskLoss/nTotal} Val Accuracy: {nCorrect/nTotal}")

#Training loop
for epoch in range(1,1+4):
    #Train
    taskLoss = 0.0
    nCorrect = 0
    nTotal = 0
    model.train()
    for i, data in enumerate(tqdm(train_dl)):
        model.zero_grad()
        output = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device))
        logits = output.logits
        loss = output.loss
        loss.backward()

        prompt = model.prompt.clone().detach()

        perplexity_model.zero_grad()
        perplexity_loss = perplexity_model(input_ids=prompt.argmax(dim=-1).unsqueeze(0),labels=prompt.argmax(dim=-1).unsqueeze(0)).loss
        perplexity_loss.backward()

        # Update Prompt with GCG
        top_k_candidates = (model.prompt.grad+1*perplexity_model.prompt.grad).topk(k,dim=1,largest=False).indices.squeeze()
        best_loss = float("inf")
        best_candidate = None
        old_prompt = model.prompt.data.clone().detach()
        for perturbation in range(b):
            new_prompt = old_prompt.clone()
            i = torch.randint(0,model.prompt_length,(1,)).item()
            j = torch.randint(0,k,(1,)).item()
            new_prompt[i] = F.one_hot(top_k_candidates[i][j],num_classes=model.vocab_size).float()
            model.prompt.data = new_prompt.to(device)
            perturbed_loss = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device)).loss
            if perturbed_loss.item()<best_loss:
                best_loss = perturbed_loss.item()
                best_candidate = new_prompt
        model.prompt.data = best_candidate.to(device)


        if i%10==0:
            print(f"{best_loss:.5g},{perplexity_loss.item()}",tokenizer.decode(torch.argmax(model.prompt.data,dim=1)))

        taskLoss += loss.item() * len(data["x"]["input_ids"])
        yes_logits = logits[:,0,yes_token]
        no_logits = logits[:,0,no_token]
        prediction = yes_logits>no_logits
        truth = (data["y"]["input_ids"][:,0]==yes_token).to(device)
        nCorrect += sum(truth==prediction)
        nTotal += len(data["x"]["input_ids"])
    print(f"Epoch {epoch} Train Loss: {taskLoss/nTotal} Train Accuracy: {nCorrect/nTotal}")

    # Validation
    model.eval()
    taskLoss = 0.0
    nCorrect = 0
    nTotal = 0
    with torch.no_grad():
        for i, data in enumerate(tqdm(val_dl)):
            output = model(input_ids=data["x"]["input_ids"].to(device), labels=data["y"]["input_ids"].to(device))
            logits = output.logits
            loss = output.loss
            taskLoss += loss.item() * len(data["x"]["input_ids"])
            nTotal += len(data["x"]["input_ids"])
            yes_logits = logits[:,0,yes_token]
            no_logits = logits[:,0,no_token]
            prediction = yes_logits>no_logits
            truth = (data["y"]["input_ids"][:,0]==yes_token).to(device)
            nCorrect += sum(truth==prediction)
    print(f"Epoch {epoch} Val Loss: {taskLoss/nTotal} Val Accuracy: {nCorrect/nTotal}")

  0%|          | 0/63 [00:00<?, ?it/s]

Before Training Val Loss: 2.7451100879776216 Val Accuracy: 0.5430861711502075


  0%|          | 0/44 [00:00<?, ?it/s]

3.6939,29.294885635375977 </s><pad> sommes<pad> merci
3.7244,13.640069961547852 X</s><pad></s>
3.7355,2.11336088180542 </s>X<unk> </s>
3.6282,5.06926155090332 </s>X<unk> </s>
3.7265,2.745450496673584 </s> <unk> <pad>
3.6144,11.93589973449707 </s> <unk> </s>
3.5256,10.726781845092773 </s> X<pad></s>
3.4465,11.17876148223877 XX<pad></s>
3.9968,2.1732890605926514 XXX </s>
3.3961,4.7957634925842285 X<unk>  </s>
3.77,4.842508316040039 X<unk>  </s>
3.5114,5.958323001861572 X</s><unk>  
Epoch 1 Train Loss: 4.143643346699801 Train Accuracy: 0.5142045617103577


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1 Val Loss: 2.6130777427810945 Val Accuracy: 0.5410821437835693


  0%|          | 0/44 [00:00<?, ?it/s]

3.5432,4.915716648101807 X</s><unk>  
4.0176,4.098322868347168 X</s></s>X 
3.8081,3.9652466773986816 <unk><unk><unk></s></s>
3.5961,13.8421049118042 </s><pad> </s>X
3.4272,13.8421049118042 </s><pad> </s>X
3.2173,22.020145416259766 </s><pad>X<pad> 
3.1666,21.365966796875 </s><pad>X<pad></s>
3.4713,12.462315559387207 </s><pad>XX</s>
3.6468,12.462315559387207 <pad><pad>XX</s>
3.5185,12.416095733642578 <pad></s><unk><unk></s>
3.74,23.040576934814453 <pad></s></s> <pad>
3.2832,16.137128829956055 </s></s>XX
Epoch 2 Train Loss: 4.031478984789415 Train Accuracy: 0.5056818127632141


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 2 Val Loss: 2.5865501335006438 Val Accuracy: 0.5470941662788391


  0%|          | 0/44 [00:00<?, ?it/s]

3.59,5.448463439941406 </s></s></s><pad>
3.7339,12.562982559204102 <unk></s></s><pad>
3.7934,10.372774124145508 <pad> </s><unk></s>
3.5029,13.834567070007324 </s></s><pad><unk>
3.8702,6.181929111480713 X</s></s><unk><unk>
3.8823,10.711732864379883 X<unk></s></s><pad>
3.7059,23.21164894104004 X<pad></s><unk><pad>
Epoch 3 Train Loss: 4.009681354869496 Train Accuracy: 0.5


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 3 Val Loss: 2.7454682844196387 Val Accuracy: 0.5370741486549377


  0%|          | 0/44 [00:00<?, ?it/s]

3.6672,2.620253801345825 XX</s><unk>X
3.5929,2.620253801345825 XX</s> X
3.8128,10.92829704284668 <unk>X</s><pad>X
3.3999,22.9484806060791 <unk> </s></s><pad>
3.4192,4.015448570251465 </s> </s></s></s>
3.4886,13.107930183410645 </s><pad><pad></s></s>
3.6516,11.300735473632812 <pad>X </s></s>
3.7753,5.2552900314331055 X X<unk>
3.7615,2.2788851261138916 </s> <unk> 
3.4771,2.978564500808716 </s> <unk>X
3.7137,13.018580436706543 </s><pad></s>X
Epoch 4 Train Loss: 4.037677374753085 Train Accuracy: 0.48295456171035767


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 4 Val Loss: 2.6048228224675976 Val Accuracy: 0.5691382884979248


Prompt Tuning with GCG + Perplexity Regularization

| Epoch | Train Loss | Train Accuracy | Val Loss | Val Accuracy |
|-------|------------|----------------|----------|--------------|
| Before Training | - | - | 2.745 | 0.543 |
| 1     | 4.144      | 0.514          | 2.613    | 0.541        |
| 2     | 4.031      | 0.506          | 2.587    | 0.547        |
| 3     | 4.010      | 0.500          | 2.745    | 0.537        |
| 4     | 4.038      | 0.483          | 2.605    | 0.569        |

the results are similar to the GCG-only training, adding perplexity regularization with the current configuration did not improve the model's performance.