In [1]:
import json
import sys
# Add VILA's eval directory to the module search path so the `eval_datasets`
# import in the following cell can resolve (presumably the module lives there —
# TODO confirm). The path is relative, so it depends on the kernel's working directory.
sys.path.append("./VILA/llava/eval/")

In [2]:
from eval_datasets import VQADataset

In [3]:
import numpy as np

def calculate_vqa_accuracy(predictions, ground_truths):
    """
    Calculate VQA accuracy for a list of predictions and ground truths.

    Each prediction is scored as ``min(1, votes / 3)``, where ``votes`` is the
    number of annotator answers that are exactly equal to the prediction. No
    normalisation is applied here, so callers must lower-case / strip answers
    themselves if that is wanted. The per-question scores are then averaged.

    Args:
        predictions (list of str): List of predicted answers by the model.
        ground_truths (list of list of str): List of lists, where each inner list contains
                                             human-provided ground truth answers for a question.

    Returns:
        float: Average VQA accuracy across all examples, or 0.0 when there are
        no examples to score.

    Raises:
        ValueError: If ``predictions`` and ``ground_truths`` differ in length.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have the same length.")

    # Nothing to score: return 0.0 rather than dividing by zero below.
    if len(predictions) == 0:
        return 0.0

    total_score = 0.0

    for pred, truths in zip(predictions, ground_truths):
        # Only the votes for the predicted answer affect the score, so count
        # those directly instead of tallying every distinct annotator answer.
        pred_votes = truths.count(pred)
        # Three or more agreeing annotators give full credit.
        total_score += min(1.0, pred_votes / 3.0)

    # Return average accuracy
    return total_score / len(predictions)

# # Example usage
# predictions = ["cat", "dog", "bird"]  # Model predictions
# ground_truths = [
#     ["cat", "dog", "cat", "cat", "bird"],  # Annotator answers for question 1
#     ["dog", "dog", "dog", "cat", "cat"],   # Annotator answers for question 2
#     ["bird", "bird", "cat", "cat"] # Annotator answers for question 3
# ]

# vqa_accuracy = calculate_vqa_accuracy(predictions, ground_truths)
# print(f"VQA Accuracy: {vqa_accuracy:.4f}")


In [4]:
# Locations of the VQA v2 validation data on the cluster scratch space:
# the val2014 image directory plus the matching questions and annotations JSON files.
val_image_dir_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/val2014"
val_questions_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_OpenEnded_mscoco_val2014_questions.json"
val_annotations_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_mscoco_val2014_annotations.json"
# The last two positional arguments (False, "vqav2") are passed through unchanged;
# presumably an is-train flag and the dataset name — TODO confirm against VQADataset's signature.
val_dataset = VQADataset(val_image_dir_path, val_questions_path, val_annotations_path,False, "vqav2")

In [6]:
# Display the first sample to see the record layout the dataset returns.
val_dataset[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x512>,
 'question': 'Where is he looking?',
 'question_id': 262148000,
 'answers': ['down',
  'down',
  'at table',
  'skateboard',
  'down',
  'table',
  'down',
  'down',
  'down',
  'down']}

In [5]:
# Collect the annotator answer lists for dataset indices 0..4999, in index order.
# get_vqa_acc pairs these positionally with the outputs of a results file, so the
# results files are assumed to follow the same ordering — TODO confirm.
# NOTE(review): each val_dataset[i] returns a full record (the displayed sample
# includes an 'image' entry), so this may be slow if images are loaded on access.
ground_truths = [val_dataset[i]['answers'] for i in range(5000)]

In [6]:
def get_vqa_acc(res_file):
    """
    Score one results file against the module-level ``ground_truths`` list.

    Args:
        res_file (str): Path to a JSON file whose top-level object has an
            'outputs' entry holding the model's answers as strings.

    Returns:
        The value returned by ``calculate_vqa_accuracy`` for the normalised
        answers and the equally long prefix of ``ground_truths``.
    """
    with open(res_file, 'rb') as handle:
        results = json.load(handle)

    # Normalise every answer: lower-case first, then trim surrounding whitespace.
    predictions = []
    for raw_answer in results['outputs']:
        predictions.append(raw_answer.lower().strip())

    # Answers are matched to ground truths purely by position, so only the first
    # len(predictions) ground-truth entries take part in the score.
    references = ground_truths[:len(predictions)]
    return calculate_vqa_accuracy(predictions, references)

In [19]:
# Evaluate the VILA1.5-13b results files labelled "random-examples" in the
# `vqa_old` results directory, printing one accuracy per shot count.
n_shots = [0,2,4,8]
print("Random in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    # One results file per shot count; the shot count is embedded in the filename.
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa_old/VILA1.5-13b_{shot}-shot_random-examples.json"
    print(get_vqa_acc(res_file))

Random in-context examples
n-shot: 0
0.7118666666666658
n-shot: 2
0.6099999999999985
n-shot: 4
0.6109999999999982
n-shot: 8
0.5900999999999986


In [22]:
# Evaluate the VILA1.5-13b results files in `vqa_old` that carry no
# "_random-examples" suffix (reported here as RICE), one accuracy per shot count.
# Relies on `n_shots` having been defined by an earlier cell.
# NOTE(review): the 0-shot accuracy printed for this cell differs sharply from the
# 0-shot accuracy of the random-examples run in the same directory — confirm both
# files cover the same questions and were produced with the same prompt.
print("RICE in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa_old/VILA1.5-13b_{shot}-shot.json"
    print(get_vqa_acc(res_file))

RICE in-context examples
n-shot: 0
0.19299999999999973
n-shot: 2
0.37836666666666696
n-shot: 4
0.4372666666666668
n-shot: 8
0.3003666666666672


In [14]:
# Evaluate the VILA1.5-13b results files labelled "random-examples" in the
# `vqa` results directory, printing one accuracy per shot count.
n_shots = [0,2,4,8]
print("Random in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    # One results file per shot count; the shot count is embedded in the filename.
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa/VILA1.5-13b_{shot}-shot_random-examples.json"
    print(get_vqa_acc(res_file))

Random in-context examples
n-shot: 0
0.34493333333333315
n-shot: 2
0.7311333333333345
n-shot: 4
0.7398666666666674
n-shot: 8
0.7420000000000011


In [15]:
# Evaluate the VILA1.5-13b results files in `vqa` that carry no
# "_random-examples" suffix (reported here as RICE), one accuracy per shot count.
# Relies on `n_shots` having been defined by an earlier cell.
print("RICE in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa/VILA1.5-13b_{shot}-shot.json"
    print(get_vqa_acc(res_file))

RICE in-context examples
n-shot: 0
0.34599999999999986
n-shot: 2
0.6881333333333347
n-shot: 4
0.6793333333333338
n-shot: 8
0.6717333333333342


In [7]:
# VILA1.5-3b: evaluate the results files labelled "random-examples" in the
# `vqa` results directory, printing one accuracy per shot count.
n_shots = [0,2,4,8]
print("Random in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    # One results file per shot count; the shot count is embedded in the filename.
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa/VILA1.5-3b_{shot}-shot_random-examples.json"
    print(get_vqa_acc(res_file))

Random in-context examples
n-shot: 0
0.014133333333333331
n-shot: 2
0.4854000000000001
n-shot: 4
0.6200000000000009
n-shot: 8
0.7208666666666673


In [8]:
# VILA1.5-3b: evaluate the results files in `vqa` that carry no
# "_random-examples" suffix (reported here as RICE), one accuracy per shot count.
# Relies on `n_shots` having been defined by an earlier cell.
print("RICE in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa/VILA1.5-3b_{shot}-shot.json"
    print(get_vqa_acc(res_file))

RICE in-context examples
n-shot: 0
0.013733333333333334
n-shot: 2
0.3533333333333331
n-shot: 4
0.43813333333333265
n-shot: 8
0.6136666666666678


In [16]:
# Evaluate the VILA1.5-13b results files stored in the `vqa_exp` results
# directory (reported here as RICE), printing one accuracy per shot count.
# get_vqa_acc scores only as many questions as each file has outputs, so these
# numbers may cover a different number of questions than other cells — TODO confirm.
print("RICE in-context examples")
n_shots = [0,2,4,8]
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa_exp/VILA1.5-13b_{shot}-shot.json"
    print(get_vqa_acc(res_file))

RICE in-context examples
n-shot: 0
0.39333333333333337
n-shot: 2
0.6733333333333335
n-shot: 4
0.7
n-shot: 8
0.7033333333333335
