In [None]:
!pip install sentence_transformers

In [14]:
import json
from sentence_transformers import SentenceTransformer, util
import numpy as np
import pandas as pd
import ast
import string
import os
from tqdm import tqdm
from constants import inferences_folder, results_folder

# Load the pre-trained sentence transformer model once at module level.
# evaluate_semantic_similarity reads this global to embed the answers.
model = SentenceTransformer('all-MiniLM-L6-v2')


def parse_plausible_answers(text):
    """
    Parses a string that represents a set of plausible answers into a Python set.

    Parameters:
    - text (str): A string that looks like a Python set of strings (e.g., "{'no', 'yes'}")

    Returns:
    - set: A set containing the plausible answers. A list or tuple literal is
      converted to a set, and a lone literal (e.g. "'yes'") becomes a
      one-element set. An empty set is returned when the text cannot be
      parsed as a Python literal or when the parsed items are unhashable.
    """
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        # literal_eval raises SyntaxError for text that is not valid Python,
        # ValueError for non-literal content (including non-string input such
        # as NaN), and TypeError for unhashable set members.
        return set()

    if isinstance(parsed, set):
        return parsed

    # Always hand back a real set: membership tests on a bare string would
    # otherwise silently turn into substring checks.
    items = parsed if isinstance(parsed, (list, tuple)) else [parsed]
    try:
        return set(items)
    except TypeError:
        # Unhashable items (e.g. nested lists or a dict) cannot be answers.
        return set()


def extract_answer(text):
    """
    Extracts the answer from a generated answer text string that follows the format:
    '...Answer:  motorcycle'

    The answer is the first whitespace-delimited word that follows the last
    colon in the text (the whole text is used when it contains no colon).

    Parameters:
    - text (str): The complete text from which to extract the answer.

    Returns:
    - str: The extracted answer, lowercased and with surrounding punctuation
      removed. An empty string is returned when nothing follows the last
      colon or when `text` is not a string (e.g. NaN from an empty CSV cell).
    """
    if not isinstance(text, str):
        return ""

    # rfind returns -1 when there is no colon, so the slice then starts at 0.
    answer_start = text.rfind(':') + 1

    # Split on any whitespace so that a newline or tab after the answer does
    # not leave trailing text glued to it.
    words = text[answer_start:].lower().split()
    if not words:
        return ""
    return words[0].strip(string.punctuation)

def evaluate_exact_match(data):
    """
    Computes the fraction of rows whose cleaned generated answer is a member
    of that row's set of plausible answers.

    Parameters:
    - data (pd.DataFrame): Must contain the columns 'cleaned_generated_answer'
      and 'plausible_answers_set'.

    Returns:
    - float: Exact-match accuracy in [0, 1]; 0.0 when `data` has no rows.
    """
    total = len(data)
    if total == 0:
        # Avoid dividing by zero on an empty frame.
        return 0.0

    # Walk the two columns in lockstep instead of building a Series per row.
    correct_count = sum(
        answer in plausible
        for answer, plausible in zip(data['cleaned_generated_answer'],
                                     data['plausible_answers_set'])
    )
    return correct_count / total

def evaluate_semantic_similarity(data, threshold=0.7):
    """
    Computes the fraction of rows whose generated answer is semantically close
    to at least one plausible answer.

    Each answer is embedded with the module-level `model`; a row counts as
    correct when the highest cosine similarity between the generated answer
    and any plausible answer is at least `threshold`.

    Parameters:
    - data (pd.DataFrame): Must contain the columns 'cleaned_generated_answer'
      and 'plausible_answers_set'.
    - threshold (float): Minimum cosine similarity for a row to count as correct.

    Returns:
    - float: Semantic-similarity accuracy in [0, 1]; 0.0 when `data` has no rows.
    """
    total = len(data)
    if total == 0:
        # Avoid dividing by zero on an empty frame.
        return 0.0

    def is_semantically_correct(row):
        candidates = list(row['plausible_answers_set'])
        if not candidates:
            # Nothing to compare against; taking the max of an empty
            # similarity matrix would raise.
            return False

        generated_answer_embedding = model.encode(row['cleaned_generated_answer'], convert_to_tensor=True)
        plausible_answers_embeddings = model.encode(candidates, convert_to_tensor=True)

        # Compute cosine similarities and check if any are above the threshold
        similarities = util.pytorch_cos_sim(generated_answer_embedding, plausible_answers_embeddings)
        max_similarity = np.max(similarities.cpu().numpy())
        return bool(max_similarity >= threshold)

    correct_counts = data.apply(is_semantically_correct, axis=1).sum()
    return correct_counts / total


def compute_accuracies(data):
    """
    Runs both accuracy metrics on `data` and returns them in one dictionary.

    Parameters:
    - data (pd.DataFrame): Rows to score; passed unchanged to
      evaluate_exact_match and evaluate_semantic_similarity.

    Returns:
    - dict: {'exact_match_accuracy': ..., 'semantic_similarity_accuracy': ...}
    """
    # Dict values are evaluated in order: exact match first, then semantic.
    return {
        'exact_match_accuracy': evaluate_exact_match(data),
        'semantic_similarity_accuracy': evaluate_semantic_similarity(data),
    }


def split_and_evaluate(data):
    """
    Scores `data` separately for each answer type and adds count-weighted
    overall accuracies.

    Rows are grouped by the 'Answer Type' column into 'other', 'yes/no' and
    'number'; rows with any other answer type are ignored.

    Parameters:
    - data (pd.DataFrame): Must contain an 'Answer Type' column plus the
      columns required by compute_accuracies.

    Returns:
    - dict: One entry per answer type of the form
      {'accuracies': {...}, 'count': int} (accuracies are None when the type
      has no rows), plus an 'overall' entry holding
      'weighted_exact_match_accuracy' and 'weighted_semantic_accuracy'.
      The overall values are None when no row matches any answer type.
    """
    answer_types = ['other', 'yes/no', 'number']
    results = {}
    total_count = 0

    for answer_type in answer_types:
        subset = data[data['Answer Type'] == answer_type]
        count = len(subset)
        if count > 0:
            accuracies = compute_accuracies(subset)
        else:
            accuracies = {
                'exact_match_accuracy': None,
                'semantic_similarity_accuracy': None
            }
        results[answer_type] = {'accuracies': accuracies, 'count': count}
        total_count += count

    def weighted_average(metric):
        # With no scored rows there is nothing to average; returning None
        # mirrors the per-type placeholder and avoids dividing by zero.
        if total_count == 0:
            return None
        weighted_sum = sum(
            info['accuracies'][metric] * info['count']
            for info in results.values()
            if info['accuracies'][metric] is not None
        )
        return weighted_sum / total_count

    # Compute the overall values before 'overall' is added to results, so the
    # averages only iterate over the per-type entries.
    overall_accuracies = {
        'weighted_exact_match_accuracy': weighted_average('exact_match_accuracy'),
        'weighted_semantic_accuracy': weighted_average('semantic_similarity_accuracy')
    }

    # Combine results and overall accuracies
    results['overall'] = overall_accuracies

    return results

def load_and_merge_data(ground_truth_csv, generated_csv):
    """
    Reads the ground-truth and generated-answer CSV files, derives the columns
    used for scoring, and inner-joins the two tables.

    Parameters:
    - ground_truth_csv (str): Path to the CSV holding a 'Plausible answers' column.
    - generated_csv (str): Path to the CSV holding a 'Generated Answer' column.

    Returns:
    - pd.DataFrame: Rows present in both files (matched on 'Image ID' and
      'Question'), with the added columns 'plausible_answers_set' and
      'cleaned_generated_answer'. Other columns shared by both files get the
      suffixes '_truth' and '_generated'.
    """
    join_keys = ['Image ID', 'Question']

    truth_df = pd.read_csv(ground_truth_csv)
    predictions_df = pd.read_csv(generated_csv)

    # Reduce each raw model output to its one-word answer.
    raw_outputs = predictions_df['Generated Answer']
    predictions_df['cleaned_generated_answer'] = raw_outputs.apply(extract_answer)

    # Turn the stringified answer sets back into Python sets.
    raw_plausible = truth_df['Plausible answers']
    truth_df['plausible_answers_set'] = raw_plausible.apply(parse_plausible_answers)

    # Keep only the question/image pairs that appear in both tables.
    return truth_df.merge(
        predictions_df,
        how='inner',
        on=join_keys,
        suffixes=('_truth', '_generated'),
    )

def save_results_to_json(filename, results):
    """
    Writes `results` to `filename` as indented JSON.

    Missing parent directories are created first, so saving into a results
    folder that does not exist yet does not fail.

    Parameters:
    - filename (str): Path of the JSON file to write (overwritten if present).
    - results: A JSON-serializable object.
    """
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

def find_csv_files(directory, suffix='.csv'):
    """
    Finds all files in the specified directory with the given suffix.

    Only the directory itself is searched (no recursion). Sub-directories
    whose names happen to end with the suffix are skipped, and the result is
    sorted by file name so the processing order is the same on every run.

    Parameters:
    - directory (str): The directory to search for files.
    - suffix (str): The file suffix to search for.

    Returns:
    - list of str: A list of full file paths that match the suffix.
    """
    # os.listdir returns entries in arbitrary order, hence the explicit sort.
    candidates = (
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.endswith(suffix)
    )
    return [path for path in candidates if os.path.isfile(path)]


def process_files(ground_truth_csv, generated_files, destination_folder):
    """
    Evaluates every generated-answer CSV against the ground truth and writes
    one JSON results file per input.

    For an input named '<name>.csv' the results are written to
    '<destination_folder>/<name>_results.json'.

    Parameters:
    - ground_truth_csv (str): Path to the ground-truth CSV.
    - generated_files (list of str): Paths of the generated-answer CSV files.
    - destination_folder (str): Folder that receives the JSON results; it is
      created if it does not exist.
    """
    # Create the output folder up front so the first save cannot fail on a
    # missing directory after the (slow) evaluation has already run.
    if destination_folder:
        os.makedirs(destination_folder, exist_ok=True)

    for generated_csv in tqdm(generated_files, desc="Processing files"):
        print(f"Processing File : {generated_csv}\n")

        # Swap only the file extension; a plain str.replace would also rewrite
        # any '.csv' that appears earlier in the file name.
        stem, _ = os.path.splitext(os.path.basename(generated_csv))
        base_filename = f"{stem}_results.json"

        # Load and merge data
        merged_data = load_and_merge_data(ground_truth_csv, generated_csv)

        # NOTE(review): looks like a debugging dump - it is overwritten on
        # every iteration and written to the current working directory.
        # Confirm nothing relies on it before removing.
        merged_data.to_csv("dummy.csv", index = True)

        # Evaluate and compute results
        evaluation_results = split_and_evaluate(merged_data)

        # Save to JSON file with the new filename
        output_path = f"{destination_folder}/{base_filename}"
        save_results_to_json(output_path, evaluation_results)
        print(f"Results saved to {output_path}")

# Constant ground truth file

In [15]:
# Ground-truth CSV passed to process_files for every inference folder.
# load_and_merge_data reads its 'Plausible answers' column and joins it to the
# generated answers on 'Image ID' and 'Question'.
ground_truth_csv = '../data/10k_mapped_question_answers.csv'

In [16]:
import os

# Collect a (name, path) pair for every entry of inferences_folder that is
# neither a regular file nor hidden (name starting with a dot).
folders = []
for entry in os.listdir(inferences_folder):
    entry_path = os.path.join(inferences_folder, entry)
    if os.path.isfile(entry_path) or entry.startswith("."):
        continue
    folders.append((entry, entry_path))
folders

[('blip_llama', '../inferences\\blip_llama'),
 ('blip_mistral', '../inferences\\blip_mistral'),
 ('blip_yolo_llama', '../inferences\\blip_yolo_llama'),
 ('blip_yolo_llama_quantized_templates',
  '../inferences\\blip_yolo_llama_quantized_templates'),
 ('blip_yolo_llama_unquantized_templates',
  '../inferences\\blip_yolo_llama_unquantized_templates'),
 ('blip_yolo_mistral', '../inferences\\blip_yolo_mistral'),
 ('blip_yolo_mistral_quantized_templates',
  '../inferences\\blip_yolo_mistral_quantized_templates'),
 ('blip_yolo_mistral_unquantized_templates',
  '../inferences\\blip_yolo_mistral_unquantized_templates'),
 ('yolo_llama', '../inferences\\yolo_llama'),
 ('yolo_mistral', '../inferences\\yolo_mistral')]

In [17]:
# list(filter(lambda x: x.contains("quantized"), folders))
# subset = list(filter(lambda x: "quant" in x[0] in x[0], folders))

In [18]:
# for folder, folder_path in subset:
#     generated_files = find_csv_files(folder_path)
#     for file in generated_files:
#         df= pd.read_csv(file)
#         df.rename({"Model Output": "Generated Answer", 
#                   "Generated Caption": "Caption", 
#                   "Generated Detections": "Detections"}, axis=1, inplace=True)
#         df.drop(["Answer Type", "Question Type", "Answer", "Plausible answers", "Image file"], axis=1, inplace=True)
#         df.to_csv(file, index = False)

In [19]:
# for folder, folder_path in subset:
#     # folder= f"{inferences_folder}/{folder}"
#     # List of generated files
#     generated_files = find_csv_files(folder_path)
#     process_files(ground_truth_csv, generated_files, f"{results_folder}/{folder}")

In [20]:
# Restrict this run to the yolo_llama and yolo_mistral folders; this replaces
# the (name, path) list that was built from inferences_folder earlier.
folders = [('yolo_llama', '../inferences/yolo_llama'), ('yolo_mistral', '../inferences/yolo_mistral')]
for folder, folder_path in folders:
    # Score every generated-answer CSV found directly inside the folder and
    # write the JSON results under results_folder/<folder name>.
    generated_files = find_csv_files(folder_path)
    process_files(ground_truth_csv, generated_files, f"{results_folder}/{folder}")

Processing files:   0%|          | 0/3 [00:00<?, ?it/s]

Processing File : ../inferences/yolo_llama\10k_enhanced_prompt_with_generation_cfg_answers.csv



Processing files:  33%|███▎      | 1/3 [00:26<00:52, 26.45s/it]

Results saved to ../results/yolo_llama/10k_enhanced_prompt_with_generation_cfg_answers_results.json
Processing File : ../inferences/yolo_llama\10k_generation_cfg_answers.csv



Processing files:  67%|██████▋   | 2/3 [00:54<00:27, 27.61s/it]

Results saved to ../results/yolo_llama/10k_generation_cfg_answers_results.json
Processing File : ../inferences/yolo_llama\10k_generation_cfg_prompt_restriction_answers.csv



Processing files: 100%|██████████| 3/3 [01:26<00:00, 28.81s/it]


Results saved to ../results/yolo_llama/10k_generation_cfg_prompt_restriction_answers_results.json


Processing files:   0%|          | 0/4 [00:00<?, ?it/s]

Processing File : ../inferences/yolo_mistral\10k_default_answers.csv



Processing files:  25%|██▌       | 1/4 [05:47<17:21, 347.23s/it]

Results saved to ../results/yolo_mistral/10k_default_answers_results.json
Processing File : ../inferences/yolo_mistral\10k_enhanced_prompt_with_generation_cfg_answers.csv



Processing files:  50%|█████     | 2/4 [11:40<11:41, 350.58s/it]

Results saved to ../results/yolo_mistral/10k_enhanced_prompt_with_generation_cfg_answers_results.json
Processing File : ../inferences/yolo_mistral\10k_generation_cfg_answers.csv



Processing files:  75%|███████▌  | 3/4 [17:19<05:45, 345.63s/it]

Results saved to ../results/yolo_mistral/10k_generation_cfg_answers_results.json
Processing File : ../inferences/yolo_mistral\10k_generation_cfg_prompt_restriction_answers.csv



Processing files: 100%|██████████| 4/4 [22:12<00:00, 333.13s/it]

Results saved to ../results/yolo_mistral/10k_generation_cfg_prompt_restriction_answers_results.json



