In [None]:
import os
import json
import torch
import logging
import datasets
import torchvision
from PIL import Image
from tqdm.auto import tqdm

# Restrict CUDA to device index 0 for this process. With a single visible
# device, the "cuda:0" chosen below refers to that GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# NOTE(review): presumably read by ModelScope to set its log level to ERROR
# (the value written is the numeric level as a string) — confirm.
os.environ["MODELSCOPE_LOG_LEVEL"] = str(logging.ERROR)
# Turn off torchvision's warning about its beta transforms API.
torchvision.disable_beta_transforms_warning()
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

## Load Dataset

In [None]:
# Only the validation split is used below, so only the validation annotation
# file is listed: `load_dataset` prepares every split named in `data_files`,
# and naming the train file would require it on disk and parse it for nothing.

# Image records (the "images" field of the COCO captions JSON).
images_dataset = datasets.load_dataset(
    "json",
    data_files = {"validation": "captions_val2017.json"},
    split = "validation",
    field = "images"
)

# Caption records (the "annotations" field of the same JSON file).
captions_dataset = datasets.load_dataset(
    "json",
    data_files = {"validation": "captions_val2017.json"},
    split = "validation",
    field = "annotations"
)

# Print all images that have more/less than 5 captions in the original dataset
##d = {}
##for entry in captions_dataset:
##    if entry["image_id"] not in d:
##        d[entry["image_id"]] = []
##    d[entry["image_id"]].append(entry["caption"])
##for item in d:
##    if len(d[item]) != 5:
##        print("Image ID:", item, "- Number of Captions:", len(d[item]))

In [None]:
import pandas as pd

# Prune dataset so that each image has one caption and each caption has one image
##pruned_captions_dataset = pd.DataFrame(captions_dataset)
##pruned_captions_dataset = pruned_captions_dataset.drop_duplicates(subset = "caption", keep = "first", ignore_index = True)
##pruned_captions_dataset = pruned_captions_dataset.drop_duplicates(subset = "image_id", keep = "first", ignore_index = True)
##pruned_captions_dataset.to_json("pruned_captions_val2017.json", orient = "records")

# Prove bijection for pruned dataset
##assert pruned_captions_dataset["caption"].nunique() == pruned_captions_dataset["image_id"].nunique() == len(pruned_captions_dataset)

# Generate new dataset by taking only the first 4% of those captions to form 200 prompts
# Rebind captions_dataset to the leading 4% slice of the pruned caption file.
captions_dataset = datasets.load_dataset(
    "json", data_files = {"validation": "pruned_captions_val2017.json"}, split = "validation[:4%]",
)

In [None]:
# Total number of whitespace-separated words over all captions. `str.split()`
# with no argument is used so that leading/trailing whitespace and runs of
# repeated spaces do not yield empty "words" that would inflate the average.
word_count = sum(len(caption.split()) for caption in captions_dataset["caption"])

print("Average Word Count Per Prompt:", word_count / len(captions_dataset))

### Download Images

In [None]:
"""
import urllib

for image in tqdm(images_dataset, desc = "Downloading images"):
    urllib.request.urlretrieve(image["coco_url"], f"coco_images/{image['id']}.jpg")
"""

## Evaluation Metrics

* <b>Alignment:</b> CLIPScore, DSG, VQAScore
* <b>Quality:</b> FID
* <b>Attack Perceptibility:</b> L-Distance

### DSG

In [None]:
from DSG.dsg.query_utils import generate_dsg
from DSG.dsg.vqa_utils import MPLUG, calc_vqa_score
from DSG.dsg.parse_utils import parse_question_output
from transformers import AutoTokenizer, AutoModelForCausalLM

# VQA model (the MPLUG wrapper imported from the DSG package); it answers the
# generated questions about each image in the evaluation loop.
vqa_model = MPLUG()
# NOTE(review): presumably switches the underlying pipeline to non-reentrant
# activation checkpointing / silences the related PyTorch warning — confirm.
vqa_model.pipeline_vqa.use_reentrant = False
# Tokenizer and causal language model behind `autocomplete`, which is the text
# generator handed to generate_dsg. Weights are loaded in bfloat16 on `device`.
llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
llm = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", device_map = device, torch_dtype = torch.bfloat16)
# `autocomplete` tokenizes with padding = True, so the tokenizer needs a pad
# token: reuse EOS for it, then point the generation config at the same id.
llm_tokenizer.pad_token = llm_tokenizer.eos_token
llm.generation_config.pad_token_id = llm_tokenizer.pad_token_id

def autocomplete(prompt, max_new_tokens = 256, **kwargs):
    """Generate a continuation of `prompt` with the module-level `llm`.

    Args:
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of newly generated tokens.
        **kwargs: Extra keyword arguments forwarded to `llm.generate`.

    Returns:
        The decoded text of the newly generated tokens only (the prompt tokens
        are sliced off), with any trailing end-of-sequence token text removed.
    """
    inputs = llm_tokenizer([prompt], return_tensors = "pt", padding = True).to(device)
    output_ids = llm.generate(**inputs, generation_config = llm.generation_config, max_new_tokens = max_new_tokens, **kwargs)
    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs.input_ids.size(dim = 1)
    completion = llm_tokenizer.batch_decode(output_ids[:, prompt_length:])[0]
    # Remove the EOS marker as a whole suffix. `str.rstrip(eos_token)` would
    # instead strip ANY trailing characters that occur in the EOS string and
    # so eat into completions that merely end in one of those characters.
    eos_token = llm_tokenizer.eos_token
    while eos_token and completion.endswith(eos_token):
        completion = completion[:-len(eos_token)]
    return completion

In [None]:
# Key each caption by its position in the dataset, wrapped in the
# {"input": text} record that is passed on to generate_dsg.
id2prompts = dict(
    (position, {"input": text})
    for position, text in enumerate(captions_dataset["caption"])
)

# generate_dsg returns three values; only the middle one (the question
# outputs per prompt id) is used afterwards.
_, id2question_outputs, _ = generate_dsg(
    id2prompts,
    generate_fn = autocomplete,
    verbose = False,
)

In [None]:
# Score every caption against its COCO image with DSG: parse the generated
# questions, answer each one with the VQA model, and aggregate the answers.

# Create the output directory up front so a missing folder fails immediately
# instead of after the whole evaluation has been computed.
os.makedirs("eval", exist_ok = True)

result = {"data": []}
for i in tqdm(id2prompts):
    # Fetch the dataset row once instead of re-indexing it for every field.
    entry = captions_dataset[i]
    # The context manager closes the image file handle. convert("RGB") reads
    # the pixels before the file is closed and yields the same RGB input that
    # the CLIPScore evaluation uses.
    with Image.open(f"coco_images/{entry['image_id']}.jpg") as image_file:
        image = image_file.convert("RGB")
    qid2question = parse_question_output(id2question_outputs[i]["output"])
    # Answers are lower-cased before being handed to calc_vqa_score.
    qid2answer = {qid: vqa_model.vqa(image, question).lower() for qid, question in qid2question.items()}
    result["data"].append({
        "Prompt": entry["caption"],
        "VQA": {"Question": qid2question, "Answer": qid2answer},
        "Score": calc_vqa_score(qid2answer)["average_score_without_dependency"],
    })

# The `with` block closes the file; no explicit close() is needed.
with open("eval/coco_captions_dsg.json", "w") as f:
    f.write(json.dumps(result))

In [None]:
# Read the saved DSG results back from disk (records live under "data").
dsg_eval_dataset = datasets.load_dataset(
    "json", data_files = {"eval": "eval/coco_captions_dsg.json"}, split = "eval", field = "data",
)

# Show one record for a sanity check, then the mean score over all records.
print("Sample Eval -", dsg_eval_dataset[0])
dsg_scores = dsg_eval_dataset["Score"]
print("DSG -", sum(dsg_scores) / len(dsg_eval_dataset))

### VQAScore

In [None]:
from t2v_metrics.t2v_metrics import VQAScore

# VQAScore metric backed by the "clip-flant5-xl" model.
clip_flant5_score = VQAScore(model = "clip-flant5-xl")

# Create the output directory up front so a missing folder fails immediately
# instead of after every score has been computed.
os.makedirs("eval", exist_ok = True)

result = {"data": []}
for entry in tqdm(captions_dataset):
    image_path = f"coco_images/{entry['image_id']}.jpg"
    # One image and one text are scored per call; .item() turns the resulting
    # single-element tensor into a plain Python number for JSON.
    score = clip_flant5_score(images = [image_path], texts = [entry["caption"]])
    result["data"].append({"Prompt": entry["caption"], "Score": score.detach().cpu().item()})

# The `with` block closes the file; no explicit close() is needed.
with open("eval/coco_captions_vqascore.json", "w") as f:
    f.write(json.dumps(result))

In [None]:
# Read the saved VQAScore results back from disk (records live under "data").
vqascore_eval_dataset = datasets.load_dataset(
    "json", data_files = {"eval": "eval/coco_captions_vqascore.json"}, split = "eval", field = "data",
)

# Show one record for a sanity check, then the mean score over all records.
print("Sample Eval -", vqascore_eval_dataset[0])
vqascore_scores = vqascore_eval_dataset["Score"]
print("VQAScore -", sum(vqascore_scores) / len(vqascore_eval_dataset))

### CLIPScore

In [None]:
import numpy as np
from torchmetrics.multimodal.clip_score import CLIPScore

# CLIPScore metric backed by the CLIP ViT-L/14 checkpoint, moved to `device`.
clip_score = CLIPScore(model_name_or_path = "openai/clip-vit-large-patch14").to(device)

# Create the output directory up front so a missing folder fails immediately
# instead of after every score has been computed.
os.makedirs("eval", exist_ok = True)

result = {"data": []}
for entry in tqdm(captions_dataset):
    # The context manager closes the image file handle; convert("RGB") reads
    # the pixels before the file is closed.
    with Image.open(f"coco_images/{entry['image_id']}.jpg") as image_file:
        image = image_file.convert("RGB")
    # The RGB image becomes a (height, width, 3) array; permute reorders it to
    # channels-first before it is passed to the metric.
    pixels = torch.tensor(np.asarray(image)).permute(2, 0, 1).to(device)
    score = clip_score(pixels, entry["caption"])
    result["data"].append({"Prompt": entry["caption"], "Score": score.detach().cpu().item()})

# The `with` block closes the file; no explicit close() is needed.
with open("eval/coco_captions_clipscore.json", "w") as f:
    f.write(json.dumps(result))

In [None]:
# Read the saved CLIPScore results back from disk (records live under "data").
clip_eval_dataset = datasets.load_dataset(
    "json", data_files = {"eval": "eval/coco_captions_clipscore.json"}, split = "eval", field = "data",
)

# Show one record for a sanity check, then the mean score over all records.
print("Sample Eval -", clip_eval_dataset[0])
clip_scores = clip_eval_dataset["Score"]
print("CLIPScore -", sum(clip_scores) / len(clip_eval_dataset))