In [None]:
import os
import json
import torch
import logging
import torchvision
from PIL import Image
from tqdm.auto import tqdm

# Expose only GPU 0 to this process. CUDA reads this variable when it is first initialized,
# so it must be set before any CUDA call.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Numeric ERROR level ("40") for ModelScope logging.
os.environ["MODELSCOPE_LOG_LEVEL"] = str(logging.ERROR)
torchvision.disable_beta_transforms_warning()
device = "cuda:0" if torch.cuda.is_available() else "cpu"
seed = 42
# Seed torch's global RNGs (CPU and all CUDA devices). This makes torch-based sampling,
# such as LLM generation when it samples, reproducible across Restart & Run All.
torch.manual_seed(seed)

## Choose an Existing Concept to Remove from Each Caption

In [None]:
import datasets

# Pruned COCO val2017 captions, exposed as a single "validation" split.
CAPTIONS_FILE = "pruned_captions_val2017.json"
captions_dataset = datasets.load_dataset("json", data_files={"validation": CAPTIONS_FILE}, split="validation")

In [None]:
from pycocotools.coco import COCO

coco = COCO("instances_val2017.json")

# Every captioned image must exist in the loaded COCO annotations. This guards against
# pairing the captions with the wrong annotation file.
caption_image_ids = set(captions_dataset["image_id"])
coco_image_ids = set(coco.getImgIds())
assert caption_image_ids <= coco_image_ids

In [None]:
import spacy
import numpy as np
from sentence_transformers import SentenceTransformer

# For each caption, pick the noun phrase most similar to any COCO category name as the
# concept to remove. The full list of candidate noun phrases is kept alongside it.
nlp = spacy.load("en_core_web_lg")
sentence_similarity_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2").to(device)

# Category names do not depend on the caption, so embed them once, as a single batch.
category_names = [category["name"] for category in coco.loadCats(coco.getCatIds())]
category_embeds = sentence_similarity_model.encode(category_names)  # one row per category

captions = captions_dataset["caption"]
removed_concepts = []
concepts_list = []
# nlp.pipe batches the spaCy parsing and yields one Doc per caption, in input order.
for doc in tqdm(nlp.pipe(captions), total = len(captions)):
    # Noun chunks with pronoun tokens dropped. Chunks that were only pronouns become empty
    # strings and are discarded.
    concepts = [" ".join(word.text for word in phrase if word.pos_ != "PRON") for phrase in doc.noun_chunks]
    concepts = [concept for concept in concepts if concept]
    if not concepts:
        # With no candidate there is nothing to remove. Report the offending caption
        # instead of failing somewhere inside the similarity computation.
        raise ValueError(f"No noun-phrase concept found in caption: {doc.text!r}")

    concept_embeds = sentence_similarity_model.encode(concepts)  # one row per concept, batched
    # Rows: concepts; columns: categories.
    similarity = sentence_similarity_model.similarity(concept_embeds, category_embeds)
    # The flat argmax divided by the column count gives the row (concept) holding the global best match.
    best_concept_idx = similarity.argmax().item() // similarity.shape[1]

    concepts_list.append(concepts)
    removed_concepts.append(concepts[best_concept_idx])

removed_concepts_dataset = (
    captions_dataset
    .add_column("removed_concept", removed_concepts)
    .add_column("concepts", concepts_list)
)
removed_concepts_dataset.to_json("pruned_captions_with_removed_concept_val2017.json")

assert removed_concepts_dataset["id"] == captions_dataset["id"] # verify that dataset order is not changed

removed_concepts_dataset

In [None]:
# Spot-check the first record, including the added "removed_concept" and "concepts" columns.
removed_concepts_dataset[0]

## Generate DSG Questions for the Removed Concept and Score Them with VQA

In [None]:
import datasets

# Reload the JSON with removed concepts, keeping only the first 4% of rows.
# NOTE(review): presumably a subset to bound LLM/VQA cost; confirm before reporting results.
eval_subset = "validation[:4%]"
removed_concepts_dataset = datasets.load_dataset(
    "json",
    data_files={"validation": "pruned_captions_with_removed_concept_val2017.json"},
    split=eval_subset,
)

In [None]:
from DSG.dsg.query_utils import generate_dsg
from DSG.dsg.vqa_utils import MPLUG, calc_vqa_score
from DSG.dsg.parse_utils import parse_question_output
from transformers import AutoTokenizer, AutoModelForCausalLM

# VQA model that answers the generated DSG questions against each image.
vqa_model = MPLUG()
# NOTE(review): intent not visible here; presumably tied to gradient checkpointing's
# `use_reentrant` flag. Confirm with the MPLUG wrapper.
vqa_model.pipeline_vqa.use_reentrant = False

# Causal LM (bfloat16, on `device`) that backs the `autocomplete` function.
llm_checkpoint = "meta-llama/Llama-3.2-3B"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_checkpoint)
llm = AutoModelForCausalLM.from_pretrained(llm_checkpoint, device_map=device, torch_dtype=torch.bfloat16)

# Use EOS as the padding token for both tokenization and generation.
llm_tokenizer.pad_token = llm_tokenizer.eos_token
llm.generation_config.pad_token_id = llm_tokenizer.pad_token_id

def autocomplete(prompt, max_new_tokens = 256, **kwargs):
    """Continue `prompt` with the LLM and return only the newly generated text.

    Args:
        prompt: Text to continue.
        max_new_tokens: Maximum number of tokens to generate.
        **kwargs: Extra keyword arguments forwarded to `llm.generate`.

    Returns:
        The decoded continuation, excluding the prompt tokens and with special tokens
        (e.g. the trailing EOS) removed.
    """
    inputs = llm_tokenizer([prompt], return_tensors = "pt", padding = True).to(device)
    output_ids = llm.generate(**inputs, generation_config = llm.generation_config, max_new_tokens = max_new_tokens, **kwargs)
    # Keep only the tokens generated after the prompt (single sequence in the batch).
    new_token_ids = output_ids[0, inputs.input_ids.size(dim = 1):]
    # Let the tokenizer drop EOS/special tokens. `str.rstrip(eos_token)` must not be used:
    # rstrip removes any trailing characters from a *set*, not a suffix, so it would also
    # eat real trailing letters such as the "text" in "...of the text".
    return llm_tokenizer.decode(new_token_ids, skip_special_tokens = True)

In [None]:
# The DSG input for each example is the removed concept itself (not the full caption),
# keyed by row index in removed_concepts_dataset.
id2prompts = {}
for idx, concept in enumerate(removed_concepts_dataset["removed_concept"]):
    id2prompts[idx] = {"input": concept}

# Only the per-id question outputs are used downstream; the other two results are discarded.
_, id2question_outputs, _ = generate_dsg(id2prompts, generate_fn=autocomplete, verbose=False)

In [None]:
# Answer each generated DSG question about the original COCO image, then save
# per-example questions, answers and scores to eval/coco_removed_concept_dsg.json.
os.makedirs("eval", exist_ok = True)  # the output directory may not exist yet

result = {"data": []}
for i in tqdm(id2prompts):
    row = removed_concepts_dataset[i]
    # Load the pixels and close the file right away. A lazily opened PIL image keeps its
    # file handle open for the whole loop. convert("RGB") also normalizes non-RGB (e.g.
    # grayscale) JPEGs.
    with Image.open(f"coco_images/{row['image_id']}.jpg") as img:
        image = img.convert("RGB")
    qid2question = parse_question_output(id2question_outputs[i]["output"])
    qid2answer = {qid: vqa_model.vqa(image, question).lower() for qid, question in qid2question.items()}
    result["data"].append({
        "Removed Concept": row["removed_concept"],
        "VQA": {"Question": qid2question, "Answer": qid2answer},
        "Score": calc_vqa_score(qid2answer)["average_score_without_dependency"],
    })

with open("eval/coco_removed_concept_dsg.json", "w") as f:
    json.dump(result, f)

In [None]:
# Reload the saved DSG results. The records live under the top-level "data" key.
dsg_eval_dataset = datasets.load_dataset(
    "json", data_files={"eval": "eval/coco_removed_concept_dsg.json"}, split="eval", field="data"
)

print("Sample Eval -", dsg_eval_dataset[0])
# Mean of the per-example "Score" values over the evaluated subset.
scores = dsg_eval_dataset["Score"]
print("DSG -", sum(scores) / len(scores))