In [None]:
import os
import json
import torch
import optuna
import torchvision

# Runtime configuration shared by every cell below.
# Restrict the CUDA devices visible to this process to device 0.
# NOTE(review): this is assigned after `import torch`; presumably CUDA has not
# been initialised yet at this point, so the setting still applies — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Quieten library output: torchvision's beta-transforms warning and Optuna logs below WARNING.
torchvision.disable_beta_transforms_warning()
optuna.logging.set_verbosity(optuna.logging.WARNING)
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Single seed reused for the Optuna sampler and for image generation.
seed = 42

## T2I Model

In [None]:
from diffusers import DiffusionPipeline

# Load the pretrained text-to-image pipeline once and move it to `device`;
# `t2i_model` below reuses this module-level object for every generation call.
model_id = "CompVis/stable-diffusion-v1-4"
# Weights are loaded in bfloat16 and no safety checker is attached to the pipeline.
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype = torch.bfloat16,
    safety_checker = None
).to(device)
# Disable the per-call progress bar of the pipeline.
pipeline.set_progress_bar_config(disable = True)

def t2i_model(prompt, *, num_images_per_prompt, num_inference_steps, seed):
    """Generate images for ``prompt`` with the module-level ``pipeline``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        num_images_per_prompt: Number of images requested for the prompt.
        num_inference_steps: Number of inference steps requested from the pipeline.
        seed: Seed used to build the random generator handed to the pipeline.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    # torch.manual_seed reseeds the global torch RNG and returns its generator,
    # so every call with the same seed starts from the same random state.
    generator = torch.manual_seed(seed)
    output = pipeline(
        prompt,
        num_images_per_prompt = num_images_per_prompt,
        num_inference_steps = num_inference_steps,
        generator = generator,
    )
    return output.images

## VLM Evaluator

In [None]:
import tempfile
from t2v_metrics.t2v_metrics import VQAScore

# Module-level VQAScore scorer built with the "clip-flant5-xl" model name;
# created once here and reused by every `vlm_evaluator` call.
clip_flant5_score = VQAScore(model = "clip-flant5-xl")

def vlm_evaluator(prompt, adversarial_images):
    """Return the mean score of ``adversarial_images`` against ``prompt``.

    Each image is written to its own temporary ``.png`` file because the scorer
    is called with file paths. The temporary files are always closed (and
    thereby removed) before this function returns or raises.

    Args:
        prompt: Text the images are scored against (passed as the only text).
        adversarial_images: Iterable of image objects exposing ``save(path)``.

    Returns:
        The mean of all scores returned by ``clip_flant5_score``, as a Python number.
    """
    temp_files = []
    try:
        for adversarial_image in adversarial_images:
            temp_file = tempfile.NamedTemporaryFile(suffix = ".png")
            # Track the file before saving so it is cleaned up even if save() fails.
            temp_files.append(temp_file)
            adversarial_image.save(temp_file.name)

        temp_file_names = [temp_file.name for temp_file in temp_files]
        scores = clip_flant5_score(images = temp_file_names, texts = [prompt])
        return scores.detach().cpu().mean().item()
    finally:
        # Closing a NamedTemporaryFile deletes it; do this on success and on error alike.
        for temp_file in temp_files:
            temp_file.close()

## Adversarial Prompt Optimizer

In [None]:
import string

class Objective:
    """Optuna objective that scores a punctuation-perturbed version of a prompt.

    Each trial picks ``m`` punctuation characters and, for each one, a slot in
    the prompt: before the first word, between two words, or after the last
    word. The perturbed prompt is rendered by ``t2i_model`` and the resulting
    images are scored against ``removed_concept`` by ``vlm_evaluator``; that
    score is the trial value.

    Args:
        prompt: Original prompt; it is stripped and split on single spaces.
        removed_concept: Text passed to ``vlm_evaluator`` together with the images.
        t2i_model: Callable ``(prompt, *, num_images_per_prompt,
            num_inference_steps, seed)`` returning the generated images.
        vlm_evaluator: Callable ``(text, images)`` returning the trial value.
        seed: Seed forwarded to ``t2i_model`` on every trial.
        m: Number of punctuation characters injected per trial.
        k: Number of images requested per perturbed prompt.
        t: Number of inference steps requested per generation.
    """

    def __init__(self, prompt, removed_concept, t2i_model, vlm_evaluator, seed, *, m, k, t):
        self.prompt = prompt
        self.removed_concept = removed_concept
        self.split_prompt = self.prompt.strip().split(" ")
        # One insertion slot before each word plus one after the last word.
        self.all_pos = len(self.split_prompt) + 1

        self.t2i_model = t2i_model
        self.vlm_evaluator = vlm_evaluator

        self.seed = seed
        self.m = m
        self.k = k
        self.t = t

    def _build_adversarial_prompt(self, char_pos_list):
        """Interleave the characters chosen for each slot with the prompt words."""
        pieces = []
        for slot, injected in enumerate(char_pos_list):
            if injected:
                pieces.append(" ".join(injected))
            # The last slot has no following word; empty words are skipped.
            if slot < len(self.split_prompt) and self.split_prompt[slot]:
                pieces.append(self.split_prompt[slot])
        return " ".join(pieces)

    def __call__(self, trial):
        """Sample one perturbation from ``trial`` and return its evaluator score."""
        char_pos_list = [[] for _ in range(self.all_pos)]
        for i in range(self.m):
            char = trial.suggest_categorical(f"char_{i}", string.punctuation)
            pos = trial.suggest_int(f"pos_{i}", 0, len(self.split_prompt)) # low and high are inclusive
            char_pos_list[pos].append(char)

        adversarial_prompt = self._build_adversarial_prompt(char_pos_list)

        # Use the callables injected through the constructor rather than any
        # same-named globals, so the objective works with whatever it was given.
        adversarial_images = self.t2i_model(adversarial_prompt,
                                            num_images_per_prompt = self.k,
                                            num_inference_steps = self.t,
                                            seed = self.seed)
        return self.vlm_evaluator(self.removed_concept, adversarial_images)

## Hyperparameters

In [None]:
# Search budget and generation settings consumed by the pipeline cell below.
n = 50 # number of Optuna trials (candidate perturbations) run per prompt
m = 3 # number of punctuation characters injected into each candidate prompt
k = 3 # number of images generated per candidate prompt
t = 1 # number of inference steps used for each generation
# Seeded sampler that proposes the (character, position) choices for each trial.
# NOTE(review): this one sampler instance is passed to every study created in the
# pipeline loop, so its random state presumably carries over from one prompt to
# the next — confirm that this is intended.
sampler = optuna.samplers.NSGAIISampler(seed = seed) # sampler to get approximate best permutation of perturbations

## Prompts

In [None]:
# Build one prompt per concept; `prompts[i]` is the prompt for `concepts[i]`, and
# the pipeline cell zips the two lists together.
# Prompt template: https://github.com/openai/CLIP/blob/dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1/data/prompts.md?plain=1#L669
concepts = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"] # CIFAR10
prompts = [f"a photo of a {concept}" for concept in concepts]

## Pipeline

In [None]:
from tqdm.auto import tqdm

# Run one minimisation study per (prompt, concept) pair and record the best
# perturbed prompt together with the score it achieved.
result = {"data": []}
prompt_concept_pairs = list(zip(prompts, concepts))
for prompt, removed_concept in tqdm(prompt_concept_pairs):
    # Search: lower evaluator score is better, hence direction "minimize".
    study = optuna.create_study(direction = "minimize", sampler = sampler)
    objective = Objective(
        prompt, removed_concept, t2i_model, vlm_evaluator, seed,
        m = m, k = k, t = t,
    )
    study.optimize(objective, n_trials = n, show_progress_bar = False)

    # Rebuild the winning prompt from the best trial's parameters, using the
    # same slot layout as the objective: slot i sits in front of word i, and
    # the final slot follows the last word.
    best_params = study.best_params
    split_prompt = prompt.strip().split(" ")
    all_pos = len(split_prompt) + 1
    char_pos_list = [[] for _ in range(all_pos)]
    for i in range(m):
        char_pos_list[best_params[f"pos_{i}"]].append(best_params[f"char_{i}"])

    pieces = []
    for slot, injected in enumerate(char_pos_list):
        if injected:
            pieces.append(" ".join(injected))
        if slot < len(split_prompt) and split_prompt[slot]:
            pieces.append(split_prompt[slot])
    adversarial_prompt = " ".join(pieces)

    # Store the outcome for this prompt.
    result["data"].append({
        "Original Prompt": prompt,
        "Removed Concept": removed_concept,
        "Adversarial Prompt": adversarial_prompt,
        "Approx. VQAScore": study.best_value,
    })

In [None]:
# Print every recorded result, one dictionary per line.
for record in result["data"]:
    print(record)