# Please run this notebook on Colab to get the correct outputs.

In [1]:
# %pip install -U transformers datasets xlsxwriter

In [2]:
import random
from datasets import load_dataset
from matplotlib import pyplot as plt
import torch
from transformers import AutoProcessor, AutoModelForPreTraining, LlavaForConditionalGeneration
import pandas as pd
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import os
import xlsxwriter
import io
import requests
import json
import base64
import re

# Download every split of A-OKVQA from the Hugging Face Hub.
dataset = load_dataset("HuggingFaceM4/A-OKVQA")

# Keep the first 500 validation examples as a plain list of row dicts. No
# filtering happens here; later cells read each row's 'image', 'question',
# 'choices' and 'correct_choice_idx'.
validation_rows = dataset['validation']
val_500_dataset = [validation_rows[row_idx] for row_idx in range(500)]
N_SAMPLES = len(val_500_dataset)

In [None]:
# Set up OpenAI API key (read from Colab's secret store)
from google.colab import userdata
api_key = userdata.get('open_ai_api_key')

# Paths for saving results (modify as needed)
RES_PATH = "./results/"
# exist_ok=True avoids the check-then-create race and is a no-op on re-runs.
os.makedirs(RES_PATH, exist_ok=True)

# File holding the running total of OpenAI spend (USD), kept across runs
COST_FILE = "total_cost.txt"

def read_total_cost():
    """Return the accumulated API cost stored in COST_FILE.

    Returns 0.0 when the file does not exist or contains only whitespace.
    """
    if not os.path.exists(COST_FILE):
        return 0.0
    with open(COST_FILE, "r") as cost_handle:
        stored = cost_handle.read().strip()
    return float(stored) if stored else 0.0

def write_total_cost(cost):
    """Add ``cost`` to the running total and overwrite COST_FILE with the sum."""
    updated = read_total_cost() + cost
    with open(COST_FILE, "w") as out:
        out.write(f"{updated}")

def calculate_cost(usage, model, verbose=0):
    """Compute the USD cost of one chat-completion call and add it to the running total.

    Args:
        usage: the response's ``usage`` dict; 'prompt_tokens' and
            'completion_tokens' are read from it.
        model: model name as reported in the API response.
        verbose: if truthy, print the cost of this call.

    Returns:
        The cost of this call in USD (it is also accumulated via write_total_cost).

    Raises:
        ValueError: if no pricing is configured for ``model``.
    """
    # USD per 1K tokens as (input, output).
    prices_per_1k = {
        "gpt-4o-2024-05-13": (0.005, 0.015),
        "gpt-4o-2024-08-06": (0.0025, 0.010),
    }
    if model not in prices_per_1k:
        # Fail loudly instead of hitting an unbound-variable error further down.
        raise ValueError(f"No pricing configured for model {model!r}")

    input_per_1k, output_per_1k = prices_per_1k[model]
    input_cost_per_token = input_per_1k / 1000
    output_cost_per_token = output_per_1k / 1000

    input_tokens = usage['prompt_tokens']
    output_tokens = usage['completion_tokens']
    cost = (input_tokens * input_cost_per_token) + (output_tokens * output_cost_per_token)
    if verbose:
        print(f"The cost incurred is ${cost:.3f}")
    write_total_cost(cost)
    return cost

def pil_image_to_base64(pil_image, img_format="JPEG"):
    """Encode a PIL image as a base64 string (without a data-URL prefix).

    Args:
        pil_image: the image to encode.
        img_format: PIL format name passed to ``save`` (default "JPEG").

    Returns:
        The base64 text of the encoded image bytes.
    """
    # JPEG cannot store alpha or palette images (e.g. RGBA, P); saving them
    # raises OSError, so convert such images to RGB first.
    if img_format.upper() in ("JPEG", "JPG") and pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")
    img_buffer = io.BytesIO()
    pil_image.save(img_buffer, format=img_format)
    return base64.b64encode(img_buffer.getvalue()).decode('utf-8')

# Build the multiple-choice prompt text (edit here to change the prompt format)
def make_prompt(x):
    """Render a row as '<question> Choices: <c1>, <c2>, ...'."""
    choice_list = ", ".join(x['choices'])
    return "{} Choices: {}".format(x['question'], choice_list)

# Define inference functions for LLaVA and GPT-4 models
def inference_llava(model, processor, image, question, with_image=True, mode="qa", max_new_tokens=40):
    """Run LLaVA beam-search generation on one prompt and return the decoded text.

    Args:
        model: loaded LLaVA model exposing ``generate`` and ``device``.
        processor: matching processor (builds model inputs, decodes token ids).
        image: image given to the processor when ``with_image`` is True.
        question: full prompt string, including the ``<image>`` placeholder.
        with_image: if False, ``image`` is ignored and a blank 224x224 white
            image is used instead, so the prompt's image slot still has pixels.
        mode: "qa" uses length_penalty=-1 (a negative penalty favours shorter
            beams in HF generation); "rationale" uses length_penalty=1.1.
        max_new_tokens: generation budget.

    Returns:
        The first beam decoded with special tokens skipped; callers in this
        notebook treat it as prompt text followed by the generated reply.

    Raises:
        ValueError: if ``mode`` is neither "qa" nor "rationale".
    """
    # Per-mode beam-search length penalty; validated before any model work.
    length_penalties = {"qa": -1, "rationale": 1.1}
    if mode not in length_penalties:
        raise ValueError(f"Unknown mode {mode!r}; expected 'qa' or 'rationale'")

    if not with_image:
        # Create a blank image when with_image is False
        image = Image.new('RGB', (224, 224), color='white')

    # Send inputs to wherever the model lives instead of hard-coding "cuda".
    inputs = processor(text=question, images=image, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs,
                             num_beams=5,
                             length_penalty=length_penalties[mode],
                             max_new_tokens=max_new_tokens)
    return processor.decode(outputs[0], skip_special_tokens=True)

def inference_gpt4(image, question, with_image=True, max_tokens=100,
                   model="gpt-4o-2024-08-06", timeout=120):
    """Send one prompt (optionally with an image) to the OpenAI Chat Completions API.

    Args:
        image: base64-encoded JPEG string; used only when ``with_image`` is True.
        question: the text prompt.
        with_image: when True, attach ``image`` as a ``data:image/jpeg;base64`` URL.
        max_tokens: completion token budget.
        model: model name sent in the request.
        timeout: seconds to wait for the HTTP request; without it a stalled
            connection could block the evaluation loop indefinitely.

    Returns:
        The assistant message content on HTTP 200; otherwise None after
        printing the error body.

    Side effects:
        On success, adds the call's cost to the running total via calculate_cost.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    # Text part always present; image part appended only when requested.
    content = [{"type": "text", "text": question}]
    if with_image:
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image}"},
        })
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": content}],
        "max_tokens": max_tokens,
    }

    response = requests.post("https://api.openai.com/v1/chat/completions",
                             headers=headers, json=payload, timeout=timeout)

    if response.status_code != 200:
        # Error bodies are usually JSON, but gateways can return plain text/HTML.
        try:
            print(response.json())
        except ValueError:
            print(response.status_code, response.text)
        return None

    body = response.json()  # parse once
    calculate_cost(body['usage'], body['model'])
    return body['choices'][0]['message']['content']

In [None]:
# LLaVA-1.5 7B checkpoint on the Hugging Face Hub.
MODEL_ID_LLAVA = "llava-hf/llava-1.5-7b-hf"
# Processor paired with the checkpoint (prepares text + image inputs, decodes outputs).
processor_llava = AutoProcessor.from_pretrained(MODEL_ID_LLAVA)
# Load the weights in bfloat16, then move the model onto the GPU.
model_llava = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID_LLAVA,
    torch_dtype=torch.bfloat16,
)
model_llava.to("cuda")

# One run per (model, image/no-image) pair; sheet_name becomes the Excel
# worksheet title for that run's results.
configs = [
    {
        'model_name': key,
        'with_image': use_image,
        'sheet_name': f"{label} {'with' if use_image else 'without'} image",
    }
    for key, label in (('llava', 'LLaVA-1.5'), ('gpt-4o', 'GPT-4o'))
    for use_image in (True, False)
]

# Define save path
SAVE_PATH = "./images/"
os.makedirs(SAVE_PATH, exist_ok=True)

# Cache each sample's image as f"{SAVE_PATH}{idx}.jpg" -- the exact path format
# recorded in the results table ('image_path') and embedded by the Excel export.
for idx, x in enumerate(tqdm(val_500_dataset, total=N_SAMPLES)):
    image = x['image']
    # JPEG cannot store alpha or palette modes (e.g. RGBA, P); saving them
    # raises OSError, so convert those to RGB first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    image.save(f"{SAVE_PATH}{idx}.jpg")

# Initialize list to store all results (one {'sheet_name', 'df'} entry per configuration)
all_results = []

# Characters stripped from both ends of an extracted answer: whitespace,
# periods, colons, quotes/backticks and markdown emphasis asterisks.
_ANSWER_STRIP_CHARS = " \t\r\n.:'\"`*"
_GPT_ANSWER_MARKER = 'Thus, the answer is'


def _strip_llava_prompt(decoded, prompt):
    """Return the assistant reply from a decoded LLaVA output that echoes the prompt.

    The decoded text omits special tokens such as <image>, so its prompt part is
    not len(prompt) characters long; slicing by length is fragile. Instead skip
    past as many 'ASSISTANT:' tags as the prompt itself contains. If the decoded
    text contains no tag, it is returned whole (stripped).
    """
    n_tags = prompt.count("ASSISTANT:")
    return decoded.split("ASSISTANT:", n_tags)[-1].strip()


def _extract_gpt_answer(text):
    """Extract the final answer from a '... Thus, the answer is X' GPT-4o reply."""
    if not text:  # failed API call (None) or empty reply
        return ""
    if _GPT_ANSWER_MARKER in text:
        tail = text.split(_GPT_ANSWER_MARKER)[-1].strip()
        # Keep only the first line after the marker; later lines are commentary.
        tail = tail.splitlines()[0] if tail else ""
    else:
        # Fallback if the expected phrase is not found: use the last line.
        tail = text.strip().split('\n')[-1]
    return tail.strip(_ANSWER_STRIP_CHARS)


# Iterate over each configuration
for config in configs:
    model_name = config['model_name']
    with_image = config['with_image']
    sheet_name = config['sheet_name']

    print(f"Processing configuration: {sheet_name}")

    results = []

    if model_name == 'llava':
        # Meta prompt for rationale
        meta_prompt = "Please explain the reasoning behind your answer?"

        model = model_llava
        processor = processor_llava

        for idx, x in enumerate(tqdm(val_500_dataset, total=N_SAMPLES)):
            image = x['image'] if with_image else None

            # Stage 1: short answer to the multiple-choice question.
            question = make_prompt(x)
            qa_question = f"<image>\nUSER:Question: {question}.\nASSISTANT:"
            qa_output = inference_llava(model, processor, image, qa_question, with_image=with_image, mode="qa", max_new_tokens=15)
            qa_answer = _strip_llava_prompt(qa_output, qa_question).strip('.')

            correct_answer = x['choices'][x['correct_choice_idx']]

            # Stage 2: ask the model to justify its own answer.
            ra_question = f"<image>\nUSER:Question: {question}. Answer: {qa_answer}. {meta_prompt}\nASSISTANT:"
            ra_output = inference_llava(model, processor, image, ra_question, with_image=with_image, mode="rationale", max_new_tokens=300)
            rationale = _strip_llava_prompt(ra_output, ra_question)

            results.append({
                'question': question,
                'predicted_answer': qa_answer,
                'correct_answer': correct_answer,
                'is_correct': int(qa_answer.lower() == correct_answer.lower()),
                'generated_rationale': rationale,
                'image_path': f"{SAVE_PATH}{idx}.jpg",
            })

    elif model_name == 'gpt-4o':
        # Instruction is identical for every sample, so build it once.
        one_shot_prompt = "First, generate a rationale for why you select a given answer for the following question. Follow this with the statement 'Thus, the answer is ' and then provide the answer."

        for idx, x in enumerate(tqdm(val_500_dataset, total=N_SAMPLES)):
            # Prepare question
            question = make_prompt(x)
            qa_rationale_question = f"{one_shot_prompt} Question and choices: {question}."

            image = pil_image_to_base64(x['image']) if with_image else None
            qa_rationale_answer = inference_gpt4(image, qa_rationale_question, with_image=with_image, max_tokens=300)
            if qa_rationale_answer is None:
                # inference_gpt4 returns None on an API error; record an empty
                # reply rather than crashing the whole (paid) run.
                qa_rationale_answer = ""

            qa_answer = _extract_gpt_answer(qa_rationale_answer)

            correct_answer = x['choices'][x['correct_choice_idx']]
            results.append({
                'question': question,
                'predicted_answer': qa_answer,
                'correct_answer': correct_answer,
                'is_correct': int(qa_answer.lower() == correct_answer.lower()),
                'generated_rationale': qa_rationale_answer,
                'image_path': f"{SAVE_PATH}{idx}.jpg",
            })

    else:
        raise ValueError(f"Unknown model_name {model_name!r} in configuration {sheet_name!r}")

    # Store results in a DataFrame
    df = pd.DataFrame(results)
    all_results.append({'sheet_name': sheet_name, 'df': df})

# Save all results to an Excel file with one sheet per configuration, including images
RES_NAME = "results.xlsx"
results_file = os.path.join(RES_PATH, RES_NAME)
with pd.ExcelWriter(results_file, engine='xlsxwriter') as writer:
    for result in all_results:
        sheet_name = result['sheet_name'][:31]  # Excel sheet name limit is 31 characters
        df = result['df']
        # Write every column (including 'image_path', so each row keeps a
        # textual reference to its picture).
        df.to_excel(writer, sheet_name=sheet_name, index=False)

        # Access the XlsxWriter worksheet created by to_excel
        worksheet = writer.sheets[sheet_name]

        if 'image_path' not in df.columns:
            continue  # nothing to embed (e.g. a configuration produced no rows)

        # Images go one blank column to the right of the data; with the six
        # result columns (A-F) this is column H. Indices are 0-based.
        image_col = len(df.columns) + 1
        worksheet.set_column(image_col, image_col, 20)

        # Row 0 is the header, so the i-th data row is worksheet row i + 1;
        # enumerate avoids depending on the DataFrame's index labels.
        for row, image_path in enumerate(df['image_path'], start=1):
            worksheet.insert_image(row, image_col, image_path, {'x_scale': 0.5, 'y_scale': 0.5})

print(f"Results saved to {results_file}")