In [11]:
import json
import os
import sys

import pandas as pd

sys.path.append("./VILA_codes/llava/eval/")

In [2]:
from eval_datasets import VQADataset, TextVQADataset

In [3]:
import numpy as np

def calculate_vqa_accuracy(predictions, ground_truths, min_votes=3):
    """
    Calculate VQA accuracy for a list of predictions and ground truths.

    Each example scores ``min(1.0, votes / min_votes)``, where ``votes`` is the
    number of ground-truth answers that are exactly equal to the prediction.
    No normalisation (case, whitespace, punctuation) is applied here; callers
    are expected to normalise both sides before calling.

    Args:
        predictions (list of str): List of predicted answers by the model.
        ground_truths (list of list of str): List of lists, where each inner list contains
                                             human-provided ground truth answers for a question.
        min_votes (int or float): Number of matching ground-truth answers required
                                  for full credit. Defaults to 3.

    Returns:
        float: Average VQA accuracy across all examples, or 0.0 when there are
               no examples to score.

    Raises:
        ValueError: If ``predictions`` and ``ground_truths`` differ in length,
                    or if ``min_votes`` is not positive.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have the same length.")
    if min_votes <= 0:
        raise ValueError("min_votes must be positive.")

    # With nothing to score the mean is undefined; return 0.0 rather than
    # failing with ZeroDivisionError on the final division.
    if len(predictions) == 0:
        return 0.0

    full_credit = float(min_votes)
    total_score = 0.0

    for pred, truths in zip(predictions, ground_truths):
        # Only the votes for the predicted answer matter, so count them directly
        # instead of tallying every distinct ground-truth answer.
        pred_votes = truths.count(pred)
        total_score += min(1.0, pred_votes / full_credit)

    # Return average accuracy
    return total_score / len(predictions)

# # Example usage
# predictions = ["cat", "dog", "bird"]  # Model predictions
# ground_truths = [
#     ["cat", "dog", "cat", "cat", "bird"],  # Annotator answers for question 1
#     ["dog", "dog", "dog", "cat", "cat"],   # Annotator answers for question 2
#     ["bird", "bird", "cat", "cat"] # Annotator answers for question 3
# ]

# vqa_accuracy = calculate_vqa_accuracy(predictions, ground_truths)
# print(f"VQA Accuracy: {vqa_accuracy:.4f}")


In [4]:
# Absolute locations of the val2014 image directory and the VQA v2 question /
# annotation JSON files used to build the evaluation dataset.
val_image_dir_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/val2014"
val_questions_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_OpenEnded_mscoco_val2014_questions.json"
val_annotations_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_mscoco_val2014_annotations.json"
# NOTE(review): the meaning of the positional `False` and "vqav2" arguments is
# defined by VQADataset in eval_datasets (not shown here) — confirm there.
val_dataset = VQADataset(val_image_dir_path, val_questions_path, val_annotations_path,False, "vqav2")

In [5]:
# Display the first sample to show the fields a dataset item exposes.
val_dataset[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x512>,
 'question': 'Where is he looking?',
 'question_id': 262148000,
 'answers': ['down',
  'down',
  'at table',
  'skateboard',
  'down',
  'table',
  'down',
  'down',
  'down',
  'down']}

In [6]:
# Answer lists for the first 5000 dataset items, in dataset order. get_vqa_acc
# pairs these with model outputs purely by position, so result files are
# presumably generated over the same items in the same order — TODO confirm.
ground_truths = [val_dataset[i]['answers'] for i in range(5000)]

In [7]:
def get_vqa_acc(res_file):
    """Compute VQA accuracy for one saved results file.

    The file is parsed as JSON and its ``'outputs'`` entries are lower-cased
    and stripped of surrounding whitespace. They are then scored against the
    leading entries of the module-level ``ground_truths`` list, matched by
    position.

    Args:
        res_file (str): Path to a JSON results file containing an ``'outputs'`` list of strings.

    Returns:
        float: The value returned by ``calculate_vqa_accuracy`` for these outputs.
    """
    with open(res_file, 'rb') as handle:
        payload = json.load(handle)

    normalized = []
    for raw_answer in payload['outputs']:
        normalized.append(raw_answer.lower().strip())

    # Only as many references as there are outputs are used.
    references = ground_truths[:len(normalized)]
    return calculate_vqa_accuracy(normalized, references)

### Generate CSVs

In [13]:
def generate_vqa_accuracy_csv(models, n_shots, strategies, file_template, output_dir):
    """Write one CSV of VQA accuracy per (model, strategy) pair.

    For every strategy and model, each shot count in ``n_shots`` is scored with
    ``get_vqa_acc`` and the rows are saved to
    ``<output_dir>/coco_<model>_<strategy>.csv`` with columns ``n_shot`` and
    ``vqa_accuracy``. The path of each written CSV is printed.

    Args:
        models (list of str): Model names substituted into ``file_template``.
        n_shots (list of int): Shot counts substituted into ``file_template``.
        strategies (list of str): Strategy names. ``"random"`` maps to the
            file-name suffix ``"_random-examples"``; any other value maps to
            an empty suffix.
        file_template (str): Format string filled with the keyword fields
            ``model``, ``shot`` and ``strategy_str``.
        output_dir (str): Directory for the CSV files; created if missing.

    Returns:
        None
    """
    # Create the destination up front so to_csv does not fail on a missing directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for strategy in strategies:
        strategy_str = "_random-examples" if strategy == "random" else ""
        for model in models:
            results = []
            for shot in n_shots:
                res_file = file_template.format(model=model, shot=shot, strategy_str=strategy_str)
                try:
                    accuracy = get_vqa_acc(res_file)
                except Exception as e:
                    # Deliberate best-effort: a missing or malformed results file
                    # is reported and recorded as an empty cell, not a crash.
                    print(f"Error processing file {res_file}: {e}")
                    accuracy = None
                results.append({"n_shot": shot, "vqa_accuracy": accuracy})

            # Fix the column order so the header is written even when n_shots is empty.
            df = pd.DataFrame(results, columns=["n_shot", "vqa_accuracy"])
            # os.path.join avoids the doubled separator when output_dir ends with "/".
            output_csv_path = os.path.join(output_dir, f"coco_{model}_{strategy}.csv")
            df.to_csv(output_csv_path, index=False)
            print(f"{output_csv_path}")


In [14]:
# Example usage: write one VQA accuracy CSV per (model, strategy) pair.
models = ['VILA1.5-3b', 'VILA1.5-13b']
strategies = ["random", "rice"]
n_shots = [0, 2, 4, 8]
# {strategy_str} becomes "_random-examples" for the "random" strategy and "" otherwise.
file_template = "/home/asureddy_umass_edu/cs682/VILA_codes/results/vqa/{model}_{shot}-shot{strategy_str}.json"
output_dir = "/home/asureddy_umass_edu/cs682/metrics_results/vqa/"

generate_vqa_accuracy_csv(models, n_shots, strategies, file_template, output_dir)

/home/asureddy_umass_edu/cs682/metrics_results/vqa//coco_VILA1.5-3b_random.csv
/home/asureddy_umass_edu/cs682/metrics_results/vqa//coco_VILA1.5-13b_random.csv
/home/asureddy_umass_edu/cs682/metrics_results/vqa//coco_VILA1.5-3b_rice.csv
/home/asureddy_umass_edu/cs682/metrics_results/vqa//coco_VILA1.5-13b_rice.csv


## Viewing VQA accuracy per n-shot setting

In [19]:
# VILA1.5-13b, random in-context examples: print the VQA accuracy for each
# shot count, reading result files from the `vqa_old` results directory.
n_shots = [0,2,4,8]
print("Random in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa_old/VILA1.5-13b_{shot}-shot_random-examples.json"
    print(get_vqa_acc(res_file))

Random in-context examples
n-shot: 0
0.7118666666666658
n-shot: 2
0.6099999999999985
n-shot: 4
0.6109999999999982
n-shot: 8
0.5900999999999986


In [22]:
# VILA1.5-13b, RICE in-context examples (result files without the
# `_random-examples` suffix) from the `vqa_old` results directory.
# Relies on `n_shots` having been defined by an earlier cell.
print("RICE in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa_old/VILA1.5-13b_{shot}-shot.json"
    print(get_vqa_acc(res_file))

RICE in-context examples
n-shot: 0
0.19299999999999973
n-shot: 2
0.37836666666666696
n-shot: 4
0.4372666666666668
n-shot: 8
0.3003666666666672


In [14]:
# VILA1.5-13b, random in-context examples: print the VQA accuracy for each
# shot count, reading result files from the `vqa` results directory.
n_shots = [0,2,4,8]
print("Random in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa/VILA1.5-13b_{shot}-shot_random-examples.json"
    print(get_vqa_acc(res_file))

Random in-context examples
n-shot: 0
0.34493333333333315
n-shot: 2
0.7311333333333345
n-shot: 4
0.7398666666666674
n-shot: 8
0.7420000000000011


In [15]:
# VILA1.5-13b, RICE in-context examples (result files without the
# `_random-examples` suffix) from the `vqa` results directory.
# Relies on `n_shots` having been defined by an earlier cell.
print("RICE in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa/VILA1.5-13b_{shot}-shot.json"
    print(get_vqa_acc(res_file))

RICE in-context examples
n-shot: 0
0.34599999999999986
n-shot: 2
0.6881333333333347
n-shot: 4
0.6793333333333338
n-shot: 8
0.6717333333333342


In [7]:
# VILA1.5-3b, random in-context examples: print the VQA accuracy for each
# shot count, reading result files from the `vqa` results directory.
n_shots = [0,2,4,8]
print("Random in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa/VILA1.5-3b_{shot}-shot_random-examples.json"
    print(get_vqa_acc(res_file))

Random in-context examples
n-shot: 0
0.014133333333333331
n-shot: 2
0.4854000000000001
n-shot: 4
0.6200000000000009
n-shot: 8
0.7208666666666673


In [8]:
# VILA1.5-3b, RICE in-context examples (result files without the
# `_random-examples` suffix) from the `vqa` results directory.
# Relies on `n_shots` having been defined by an earlier cell.
print("RICE in-context examples")
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa/VILA1.5-3b_{shot}-shot.json"
    print(get_vqa_acc(res_file))

RICE in-context examples
n-shot: 0
0.013733333333333334
n-shot: 2
0.3533333333333331
n-shot: 4
0.43813333333333265
n-shot: 8
0.6136666666666678


In [16]:
# VILA1.5-13b, RICE in-context examples: print the VQA accuracy for each shot
# count, reading result files from the `vqa_exp` results directory.
print("RICE in-context examples")
n_shots = [0,2,4,8]
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa_exp/VILA1.5-13b_{shot}-shot.json"
    print(get_vqa_acc(res_file))

RICE in-context examples
n-shot: 0
0.39333333333333337
n-shot: 2
0.6733333333333335
n-shot: 4
0.7
n-shot: 8
0.7033333333333335


### TextVQA dataset

In [15]:
def generate_vqa_accuracy_csv_textvqa(models, n_shots, strategies, file_template, output_dir):
    """Write one CSV of TextVQA accuracy per (model, strategy) pair.

    Each results file is parsed as JSON. Every entry of its ``"outputs"`` list
    supplies a model answer under ``'output'`` (lower-cased and stripped before
    scoring) and its reference answers under ``'references'``. The pairs are
    scored with ``calculate_vqa_accuracy`` and the rows are saved to
    ``<output_dir>/textvqa_<model>_<strategy>.csv`` with columns ``n_shot`` and
    ``vqa_accuracy``. The path of each written CSV is printed.

    Args:
        models (list of str): Model names substituted into ``file_template``.
        n_shots (list of int): Shot counts substituted into ``file_template``.
        strategies (list of str): Strategy names. ``"random"`` maps to the
            file-name suffix ``"_random-examples"``; any other value maps to
            an empty suffix.
        file_template (str): Format string filled with the keyword fields
            ``model``, ``shot`` and ``strategy_str``.
        output_dir (str): Directory for the CSV files; created if missing.

    Returns:
        None
    """
    # Create the destination up front so to_csv does not fail on a missing directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for strategy in strategies:
        strategy_str = "_random-examples" if strategy == "random" else ""
        for model in models:
            results = []
            for shot in n_shots:
                res_file = file_template.format(model=model, shot=shot, strategy_str=strategy_str)
                try:
                    # Context manager closes the file handle as soon as parsing is done.
                    with open(res_file) as results_fh:
                        data = json.load(results_fh)
                    outs = [x['output'].lower().strip() for x in data["outputs"]]
                    refs = [x['references'] for x in data["outputs"]]
                    accuracy = calculate_vqa_accuracy(outs, refs)
                except Exception as e:
                    # Deliberate best-effort: a missing or malformed results file
                    # is reported and recorded as an empty cell, not a crash.
                    print(f"Error processing file {res_file}: {e}")
                    accuracy = None
                results.append({"n_shot": shot, "vqa_accuracy": accuracy})

            # Fix the column order so the header is written even when n_shots is empty.
            df = pd.DataFrame(results, columns=["n_shot", "vqa_accuracy"])
            # os.path.join avoids the doubled separator when output_dir ends with "/".
            output_csv_path = os.path.join(output_dir, f"textvqa_{model}_{strategy}.csv")
            df.to_csv(output_csv_path, index=False)
            print(f"{output_csv_path}")


In [16]:
# Example usage: write one TextVQA accuracy CSV per (model, strategy) pair.
models = ['VILA1.5-3b', 'VILA1.5-13b']
strategies = ["random", "rice"]
n_shots = [0, 2, 4, 8]
# {strategy_str} becomes "_random-examples" for the "random" strategy and "" otherwise.
file_template = "/home/asureddy_umass_edu/cs682/VILA_codes/results/textvqa/{model}_{shot}-shot{strategy_str}.json"
# Same output directory as the VQA CSVs; the files are told apart by their "textvqa_" prefix.
output_dir = "/home/asureddy_umass_edu/cs682/metrics_results/vqa/"

generate_vqa_accuracy_csv_textvqa(models, n_shots, strategies, file_template, output_dir)

/home/asureddy_umass_edu/cs682/metrics_results/vqa//textvqa_VILA1.5-3b_random.csv
/home/asureddy_umass_edu/cs682/metrics_results/vqa//textvqa_VILA1.5-13b_random.csv
/home/asureddy_umass_edu/cs682/metrics_results/vqa//textvqa_VILA1.5-3b_rice.csv
/home/asureddy_umass_edu/cs682/metrics_results/vqa//textvqa_VILA1.5-13b_rice.csv


In [9]:
# VILA1.5-3b on TextVQA, random in-context examples: print accuracy per shot count.
print("Random in-context examples")
n_shots = [0,2,4,8]
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA_codes/results/textvqa/VILA1.5-3b_{shot}-shot_random-examples.json"
    # Context manager closes the results file as soon as it has been parsed.
    with open(res_file) as results_fh:
        data = json.load(results_fh)
    # Each entry carries the model answer and its reference answers.
    outs = [x['output'].lower().strip() for x in data["outputs"]]
    refs = [x['references'] for x in data["outputs"]]
    print(calculate_vqa_accuracy(outs, refs))

Random in-context examples
n-shot: 0
0.06459999999999999
n-shot: 2
0.4635333333333337
n-shot: 4
0.4813333333333341
n-shot: 8
0.49033333333333395


In [10]:
# VILA1.5-13b on TextVQA, random in-context examples: print accuracy per shot count.
print("Random in-context examples")
n_shots = [0,2,4,8]
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA_codes/results/textvqa/VILA1.5-13b_{shot}-shot_random-examples.json"
    # Context manager closes the results file as soon as it has been parsed.
    with open(res_file) as results_fh:
        data = json.load(results_fh)
    # Each entry carries the model answer and its reference answers.
    outs = [x['output'].lower().strip() for x in data["outputs"]]
    refs = [x['references'] for x in data["outputs"]]
    print(calculate_vqa_accuracy(outs, refs))

Random in-context examples
n-shot: 0
0.4626000000000003
n-shot: 2
0.5276666666666667
n-shot: 4
0.5314
n-shot: 8
0.5280000000000001


In [12]:
# VILA1.5-3b on TextVQA, RICE in-context examples (result files without the
# `_random-examples` suffix): print accuracy per shot count.
print("RICE in-context examples")
n_shots = [0,2,4,8]
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA_codes/results/textvqa/VILA1.5-3b_{shot}-shot.json"
    # Context manager closes the results file as soon as it has been parsed.
    with open(res_file) as results_fh:
        data = json.load(results_fh)
    # Each entry carries the model answer and its reference answers.
    outs = [x['output'].lower().strip() for x in data["outputs"]]
    refs = [x['references'] for x in data["outputs"]]
    print(calculate_vqa_accuracy(outs, refs))

RICE in-context examples
n-shot: 0
0.06453333333333333
n-shot: 2
0.3786666666666663
n-shot: 4
0.41153333333333314
n-shot: 8
0.4347999999999999


In [13]:
# VILA1.5-13b on TextVQA, RICE in-context examples (result files without the
# `_random-examples` suffix): print accuracy per shot count.
print("RICE in-context examples")
n_shots = [0,2,4,8]
for shot in n_shots:
    print(f"n-shot: {shot}")
    res_file = f"/home/asureddy_umass_edu/cs682/VILA_codes/results/textvqa/VILA1.5-13b_{shot}-shot.json"
    # Context manager closes the results file as soon as it has been parsed.
    with open(res_file) as results_fh:
        data = json.load(results_fh)
    # Each entry carries the model answer and its reference answers.
    outs = [x['output'].lower().strip() for x in data["outputs"]]
    refs = [x['references'] for x in data["outputs"]]
    print(calculate_vqa_accuracy(outs, refs))

RICE in-context examples
n-shot: 0
0.46400000000000025
n-shot: 2
0.4458
n-shot: 4
0.46980000000000016
n-shot: 8
0.47233333333333377
