In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; later cells save images and result
# spreadsheets under /content/drive/MyDrive.
drive.mount("/content/drive")

In [None]:
%%capture
!pip install datasets
!pip install accelerate
!pip install xlsxwriter

In [None]:
import random
from datasets import load_dataset

# Load the A-OKVQA dataset; later cells index its "train" split and read the
# "image", "question", "choices", "correct_choice_idx", "rationales" and
# "direct_answers" fields of each example.
dataset = load_dataset("HuggingFaceM4/A-OKVQA")

In [None]:
from matplotlib import pyplot as plt


def imshow(img, ax=None, caption=""):
    """Draw ``img`` on an axes with ``caption`` as its title and the axis hidden.

    Args:
        img: Image passed straight to ``Axes.imshow``.
        ax: Axes to draw on; a new figure with one axes is created when None.
        caption: Title text for the axes.

    Returns:
        The axes the image was drawn on.
    """
    target = ax
    if target is None:
        _, target = plt.subplots()
    target.imshow(img)
    target.set_title(caption)
    target.axis("off")
    return target

In [None]:
def plot_images_with_captions(images, captions):
    """Show each image in its own row of a one-column figure, titled with its caption.

    Args:
        images: Sequence of images accepted by ``imshow``.
        captions: Sequence of title strings, paired with ``images`` by position.

    Returns:
        None. The figure is displayed with ``plt.show()``; nothing is drawn when
        ``images`` is empty.
    """
    if len(images) == 0:
        # plt.subplots cannot build a grid with zero rows.
        return
    # squeeze=False always yields a 2-D array of axes; without it a single image
    # produces a bare Axes object and indexing it fails.
    fig, axes = plt.subplots(
        nrows=len(images), ncols=1, figsize=(10, 5 * len(images)), squeeze=False
    )
    for ax, img, caption in zip(axes[:, 0], images, captions):
        imshow(img, ax, caption)
    plt.show()


# Preview the first two training examples, using each question as the caption.
preview_rows = [dataset["train"][i] for i in range(2)]
images = [row["image"] for row in preview_rows]
captions = [row["question"] for row in preview_rows]
plot_images_with_captions(images, captions)

# Load BLIP-2 model 🏋

In [None]:
# from PIL import Image
# import requests
# from transformers import AutoProcessor, BlipForQuestionAnswering

# model = BlipForQuestionAnswering.from_pretrained("Salesforce/instructblip-flan-t5-xxl")
# processor = AutoProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xxl")
# --------------------------------------------#

import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Load the BLIP-2 processor and model from the same checkpoint; the model
# weights are loaded as float16 and moved to the GPU.
BLIP2_CHECKPOINT = "Salesforce/blip2-flan-t5-xl"

processor = Blip2Processor.from_pretrained(BLIP2_CHECKPOINT)
# model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xxl", device_map="auto")
model = Blip2ForConditionalGeneration.from_pretrained(
    BLIP2_CHECKPOINT, torch_dtype=torch.float16
)
model = model.to("cuda")

In [None]:
# Count every parameter of the loaded model and report the total in millions.
total_params = 0
for param in model.parameters():
    total_params += param.numel()

human_readable_params = "{:.2f}M".format(total_params / 1e6)
print("Total Parameters: " + human_readable_params)

# Running inference on A-OKVQA 🏃

In [None]:
from typing import List, Dict, Any

In [None]:
import warnings

# Suppress every Python warning for the rest of the session to keep cell output
# short. NOTE(review): this also hides deprecation and numerical warnings from
# the libraries used below; consider narrowing the filter.
warnings.filterwarnings("ignore")

In [None]:
def inference(image, question, mode="qa", hyperparams=None):
    """Run the global BLIP-2 ``model`` on one image/text pair and return the decoded text.

    Args:
        image: Image handed to the global ``processor``.
        question: Prompt text handed to the global ``processor``.
        mode: ``"qa"`` or ``"rationale"``; selects the default generation
            settings (``num_beams=5`` with ``length_penalty=-1`` for ``"qa"``,
            ``num_beams=5`` with ``length_penalty=1`` for ``"rationale"``).
        hyperparams: Optional overrides for the generation settings. Either a
            flat mapping such as ``{"length_penalty": 2}`` or a mapping keyed by
            mode such as ``{"qa": {...}, "rationale": {...}}``, in which case
            only the entry for ``mode`` is used. A ``"meta_prompts"`` entry is
            ignored because it is prompt text, not a generation setting.

    Returns:
        The first generated sequence, decoded with special tokens removed.

    Raises:
        ValueError: If ``mode`` is not ``"qa"`` or ``"rationale"``.
        TypeError: If the selected overrides are not a dict (for example an
            unexpanded grid whose values are still lists of candidates).
    """
    default_settings = {
        "qa": {"num_beams": 5, "length_penalty": -1},
        "rationale": {"num_beams": 5, "length_penalty": 1},  # choose from [1, 1.5, 2]
    }
    # Validate before preprocessing so a typo in ``mode`` fails immediately
    # instead of surfacing later as an unbound ``outputs`` variable.
    if mode not in default_settings:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of {sorted(default_settings)}"
        )

    generation_kwargs = dict(default_settings[mode])
    if hyperparams:
        overrides = hyperparams.get(mode, hyperparams)
        if not isinstance(overrides, dict):
            raise TypeError(
                f"hyperparams for mode {mode!r} must be a dict of settings, "
                f"got {type(overrides).__name__}"
            )
        non_generation_keys = {"meta_prompts"}
        for key, value in overrides.items():
            if key not in non_generation_keys:
                generation_kwargs[key] = value

    inputs = processor(images=image, text=question, return_tensors="pt").to(
        "cuda", torch.float16
    )
    outputs = model.generate(**inputs, **generation_kwargs)

    # Decode the first generated sequence.
    answer = processor.decode(outputs[0], skip_special_tokens=True)
    return answer


def make_prompt(x):
    """Build the model prompt for one example: the question followed by its choices.

    Kept as a separate hook so the question text can be preprocessed here later.

    Args:
        x: Mapping with a ``"question"`` entry and a ``"choices"`` list of strings.

    Returns:
        ``"<question> Choices: <choice>, <choice>, ..."``.
    """
    question_text = x["question"]
    joined_choices = ", ".join(x["choices"])
    return "{} Choices: {}".format(question_text, joined_choices)

In [None]:
# Answer one randomly chosen training example, then ask the model to justify
# the answer it gave.
N = len(dataset["train"])
# randrange(N) draws from 0..N-1. randint(0, N) includes N itself, which is one
# past the last valid index and makes the lookup below fail.
idx = random.randrange(N)
# idx = 1605  # <---- uncomment to inspect a specific example instead of a random one

x = dataset["train"][idx]
image = x["image"]
question = make_prompt(x)  # question text followed by "Choices: ..."

# run inference
answer = inference(image, question, mode="qa")

imshow(image)
print(f"Example: {idx}")

print(f"question: {question}")
print(f"correct_choice: {x['choices'][x['correct_choice_idx']]}")
print(f"answer: {answer}")
print(f"rationale: {x['rationales']}")
print(f"correct answer: {x['direct_answers']}")

print("----------------------------")
print("[FLAN-T5 RATIONALE]")
# run rationale inference: feed the question and the model's own answer back in
ra_question = "Question: " + question + ". Answer: " + answer + ". " + "Explain?"
# Stored under its own name so `answer` keeps holding the QA answer.
generated_rationale = inference(image, ra_question, mode="rationale")
print(f"question: {ra_question} \n rationale: {generated_rationale}")

print("----------------------------")

# Helper methods

In [None]:
# Candidate follow-up prompts for the rationale step. The sweep appends one of
# these after "Question: ... Answer: ..." to ask the model to justify its answer.
# The first group is short prompts; the second group is full-sentence prompts.
meta_prompts = [
    "Explain why?",
    "Explain your answer.",
    "Why?",
    "How did you decide?",
    "What supports your answer?",
    "Reasoning?",
    "Why this answer?",
    "What evidence?",
    "Your reasoning?",
    "Explain, please.",
    "Why or how?",
    "Any justification?",
    "Elaborate briefly.",
    "Quick rationale?",
    "Why say that?",
    "How come?",
    "Evidence?",
    "Brief explanation?",
    "Why think so?",
    "Rationale?",
    "Explain briefly.",
    ####### BIGGER ONES #######
    "Please explain the reasoning behind your answer.",
    "Can you provide a rationale for your response?",
    "What evidence from the image supports your answer?",
    "How does the information in the image lead you to your conclusion?",
    "Could you elaborate on how you derived your answer from the visual content?",
    "What aspects of the image inform your answer?",
    "Can you break down the thought process that led to your answer?",
    "In what way does the image content justify your response?",
    "How do you interpret the visual information to arrive at your answer?",
    "What visual clues in the image guided your answer?",
    "Could you describe the link between the image details and your answer?",
    "How does the image context support your reasoning?",
    "Can you analyze the image and explain how it leads to your answer?",
    "What in the image makes you say that?",
    "How do you justify your answer based on the image analysis?",
    "Please provide a detailed explanation of your answer using evidence from the image.",
    "Can you connect your answer to specific elements in the image?",
    "What reasoning process did you follow based on the visual data?",
    "How does the visual content influence your answer?",
    "Can you detail the rationale behind your answer with references to the image?",
]

In [None]:
# Search space for the sweep. Each leaf is a list of candidate values; the grid
# expansion below produces one combination per element of the Cartesian product.
hyperparams = {
    # Settings for the question-answering step (a single combination).
    "qa": {
        "num_beams": [5],
        "length_penalty": [-1],
    },
    # Settings for the rationale step: every length penalty is paired with
    # every follow-up prompt.
    "rationale": {
        "num_beams": [5],
        "length_penalty": [1, 1.5, 2],
        "meta_prompts": meta_prompts,
    },
}

In [None]:
import itertools


def dict_product(d):
    """Yield one dict per element of the Cartesian product of ``d``'s values.

    Args:
        d: Mapping whose values are iterables of candidate values.

    Yields:
        Dicts with the same keys as ``d``, each key mapped to one candidate.
        Combinations are produced in ``itertools.product`` order (last key
        varies fastest).
    """
    keys = list(d)
    candidate_lists = [d[key] for key in keys]
    for instance in itertools.product(*candidate_lists):
        yield dict(zip(keys, instance))


def recursive_dict_product(d):
    """Expand a possibly nested search space into all concrete combinations.

    Nested dict values are expanded first into the list of their own
    combinations, then the product is taken across the top-level keys.

    The expansion is built in a new dict: ``d`` is left untouched, so the same
    search-space definition can be expanded again or inspected afterwards.

    Args:
        d: Mapping whose values are either iterables of candidates or nested
            mappings of the same form.

    Returns:
        A generator of dicts, one per combination.
    """
    expanded = {
        key: list(recursive_dict_product(value)) if isinstance(value, dict) else value
        for key, value in d.items()
    }
    return dict_product(expanded)


# Materialise every hyperparameter combination and display the resulting list.
x = [*recursive_dict_product(hyperparams)]
x

# Running sweep on selected examples
Pick an example index from the cell above and assign it to the `idx` variable in the cell below.

Interesting indices to try:
- 1605: the bridge and train one

In [None]:
# run sweep
# For one example, try every hyperparameter/prompt combination: answer the
# question, ask the model to justify that answer, and collect everything into
# a spreadsheet next to a copy of the image.
import os

import pandas as pd

N = len(dataset["train"])
# idx = random.randrange(N)
idx = 11209  # <--- CHANGE THIS TO WHATEVER IDX YOU WANT

IMG_DIR = "/content/drive/MyDrive/<hidden>/img/"
RES_PATH = "/content/drive/MyDrive/<hidden>/results/"
# Create both output directories up front. exist_ok=True makes the cell safe to
# re-run, and the image directory now exists before the image is saved into it.
os.makedirs(IMG_DIR, exist_ok=True)
os.makedirs(RES_PATH, exist_ok=True)

x = dataset["train"][idx]
image = x["image"]
question = make_prompt(x)  # question text followed by "Choices: ..."
correct_choice = x["choices"][x["correct_choice_idx"]]

print(f"Example: {idx}")
imshow(image)
# save image
img_path = os.path.join(IMG_DIR, f"{idx}.jpg")
image.save(img_path)

grid = list(recursive_dict_product(hyperparams))

results = []
for comb in grid:
    # run inference
    prompt = comb["rationale"]["meta_prompts"]
    qa_answer = inference(image, question, mode="qa", hyperparams=comb)
    print(f"question: {question}")
    print(f"correct_choice: {correct_choice}")
    # Print the answer produced in this iteration, not a value left over from
    # an earlier cell or from the previous iteration's rationale.
    print(f"answer: {qa_answer}")
    print(f"rationale: {x['rationales']}")
    print(f"correct answer: {x['direct_answers']}")

    print("[FLAN-T5 RATIONALE]")
    # run rationale inference
    print(f"prompt: {prompt}")
    ra_question = "Question: " + question + ". Answer: " + qa_answer + ". " + prompt
    # Hand the combination to the rationale call as well, mirroring the QA call.
    generated_rationale = inference(
        image, ra_question, mode="rationale", hyperparams=comb
    )
    print(f"question: {ra_question} \n rationale: {generated_rationale}")

    results.append(
        {
            "question": question,
            "correct_choice": correct_choice,
            "qa_answer": qa_answer,
            "rationale": x["rationales"],
            "direct_answer": x["direct_answers"],
            "prompt": prompt,
            "generated_rationale": generated_rationale,
        }
    )
    print("-" * 100)

## save
RES_NAME = f"{idx}_num_beams_{5}_length_penalty_{1}.xlsx"
df = pd.DataFrame(results)
# The context manager writes and closes the workbook on exit; ExcelWriter.save()
# no longer exists in pandas 2.x.
with pd.ExcelWriter(os.path.join(RES_PATH, RES_NAME), engine="xlsxwriter") as writer:
    # Convert the dataframe to an XlsxWriter Excel object.
    df.to_excel(writer, sheet_name="Sheet1")

    # Insert the example image to the right of the table.
    worksheet = writer.sheets["Sheet1"]
    worksheet.insert_image("R1", img_path)

# Choosing examples

In [None]:
# Browse random training examples to find interesting ones for the sweep:
# answer the question, then ask the model to justify the answer it gave.
N = len(dataset["train"])
# randrange(N) draws from 0..N-1. randint(0, N) includes N itself, which is one
# past the last valid index and makes the lookup below fail.
idx = random.randrange(N)
# idx = 100
x = dataset["train"][idx]
image = x["image"]
question = make_prompt(x)  # question text followed by "Choices: ..."

# run inference
answer = inference(image, question, mode="qa")

imshow(image)
print(f"Example: {idx}")

print(f"question: {question}")
print(f"correct_choice: {x['choices'][x['correct_choice_idx']]}")
print(f"answer: {answer}")
print(f"rationale: {x['rationales']}")
print(f"correct answer: {x['direct_answers']}")

print("----------------------------")
print("[FLAN-T5 RATIONALE]")
# run rationale inference: feed the question and the model's own answer back in
ra_question = "Question: " + question + ". Answer: " + answer + ". " + "Explain?"
# Stored under its own name so `answer` keeps holding the QA answer.
generated_rationale = inference(image, ra_question, mode="rationale")
print(f"question: {ra_question} \n rationale: {generated_rationale}")

print("----------------------------")