In [1]:
import argparse
import json
import collections
import random
import pandas as pd    
from nltk.translate.bleu_score import sentence_bleu
from eval_metrics.evaluate_metrics import calculate_exactmatch, calculate_f1score, bleu, calculate_appearance_with_normalization
from tabulate import tabulate
from eval_metrics.glossary import *

def evaluate(gt, pred, return_pred=False):
    """Score open-ended predictions against ground truth and print a summary.

    ``gt`` and ``pred`` are walked in lockstep with ``zip``. Only pairs whose
    ground-truth ``answer_type`` equals ``'OPEN'`` are scored. Scoring is done
    on lower-cased strings passed through ``normalize_word``; a leading
    ``'assistant:'`` prefix is stripped from the prediction first.

    Side effect: every prediction dict that has no ``'gt'`` key gets one,
    holding the raw (not lower-cased, not normalised) ground-truth answer.

    Args:
        gt: Sequence of ground-truth items. Each item must provide
            ``'answer_type'`` and a ``'conversations'`` list (the misspelled
            key ``'conversatons'`` is accepted as a fallback) whose second
            entry holds the reference answer under ``'value'``.
        pred: Sequence of prediction dicts providing ``'text'`` and, for
            OPEN questions, ``'question_id'``.
        return_pred: When true, also return ``pred`` (the same object that
            was passed in, including any ``'gt'`` keys added here).

    Returns:
        A list with one score dict per OPEN question -- all of them, not a
        truncated top-N -- sorted by ascending F1. If ``return_pred`` is
        true, a ``(scores, pred)`` tuple is returned instead.
    """
    assistant_prefix = 'assistant:'
    question_analysis = []  # one dict of detailed scores per OPEN question

    for gt_item, pred_item in zip(gt, pred):
        try:
            gt_results = gt_item['conversations']
        except KeyError:
            # Fall back to the misspelled key.
            gt_results = gt_item['conversatons']
        if 'gt' not in pred_item:
            pred_item['gt'] = gt_results[1]['value']

        gt_value = gt_results[1]['value'].lower()
        pred_value = pred_item['text'].lower()
        if pred_value.startswith(assistant_prefix):
            pred_value = pred_value[len(assistant_prefix):].strip()

        gt_value = normalize_word(gt_value)
        pred_value = normalize_word(pred_value)

        if gt_item['answer_type'] != 'OPEN':
            continue

        question_id = pred_item['question_id']
        exact_match = calculate_exactmatch(pred_value, gt_value)
        f1, precision, recall = calculate_f1score(pred_value, gt_value)

        # Tokenise once and reuse for every BLEU variant.
        references = [str(gt_value).split()]
        hypothesis = str(pred_value).split()
        # Default weights; named `bleu_default` so the `bleu` function
        # imported at the top of the notebook is not shadowed.
        bleu_default = sentence_bleu(references=references, hypothesis=hypothesis)
        # Each of the following puts all weight on a single n-gram order
        # (1, 2 or 3); they are not cumulative BLEU-n scores.
        bleu_1 = sentence_bleu(references=references, hypothesis=hypothesis, weights=(1, 0, 0, 0))
        bleu_2 = sentence_bleu(references=references, hypothesis=hypothesis, weights=(0, 1, 0, 0))
        bleu_3 = sentence_bleu(references=references, hypothesis=hypothesis, weights=(0, 0, 1, 0))

        question_analysis.append({
            'question_id': question_id,
            'exact_match': exact_match,
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'bleu': bleu_default,
            'bleu_1': bleu_1,
            'bleu_2': bleu_2,
            'bleu_3': bleu_3
        })

    num_open = len(question_analysis)

    def _mean(metric):
        # Average of one metric over all OPEN questions; 0 when there are none.
        if not num_open:
            return 0
        return sum(row[metric] for row in question_analysis) / num_open

    # Print summary metrics (scaled by 100).
    print(f'num_open {num_open}')
    print(tabulate(
        [
            ['exact match score', _mean('exact_match') * 100],
            ['f1 score', _mean('f1') * 100],
            ['precision', _mean('precision') * 100],
            ['recall', _mean('recall') * 100],
            ['bleu_score', _mean('bleu') * 100],
            ['bleu_score_1', _mean('bleu_1') * 100],
            ['bleu_score_2', _mean('bleu_2') * 100],
            ['bleu_score_3', _mean('bleu_3') * 100]
        ],
        headers=['Metric', 'Performance']
    ))

    # Worst-performing questions first; the full list is returned so callers
    # can slice whatever range they want to inspect.
    low_performance_questions = sorted(question_analysis, key=lambda row: row['f1'])
    if return_pred:
        return low_performance_questions, pred
    return low_performance_questions

def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    The file at ``path`` is opened as UTF-8 text and every non-blank line is
    parsed with ``json.loads``. Blank or whitespace-only lines (for example
    a trailing empty line) are skipped instead of raising
    ``json.JSONDecodeError``.

    Args:
        path: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one parsed object per non-blank line, in file order.
    """
    with open(path, 'r', encoding='utf-8') as reader:
        return [json.loads(line) for line in reader if line.strip()]

In [26]:
# Run configuration for the checkpoint evaluated in this cell.
visual_enhance_ratio = 0.08
bbox_ratio = 0.03
epoch_num = 6
ROOT_PATH = "/data/aofei"
dataset = "Slake"

# NOTE(review): `dir` shadows the builtin of the same name; the name is kept
# here in case other cells read it -- consider renaming it notebook-wide.
dir = f"llava_med/moe_img_dense_all_query/all_expert_8_16_rank16/lora_{visual_enhance_ratio}_bbox_{bbox_ratio}/epoch{epoch_num}"
gt_file = f"{ROOT_PATH}/hallucination/{dataset}/data/test.json"
pred_file = f"{ROOT_PATH}/hallucination/mitigation/{dataset}/{dir}/inference/pred.jsonl"

# Report the configured dataset name. Re-deriving it with
# gt_file.split("/")[-2] yields the literal "data" directory component,
# so `dataset` is used directly and is no longer overwritten.
print(f"\n========\n {dataset}")

# Load ground truth with a context manager so the file handle is closed.
with open(gt_file, 'r', encoding='utf-8') as gt_reader:
    gt = json.load(gt_reader)
pred = load_jsonl(pred_file)

# evaluate() pairs items positionally, so the id sequences must be identical.
gt_ids = [item['id'] for item in gt]
pred_ids = [item['question_id'] for item in pred]
num_gt_ids, num_pred_ids = len(gt_ids), len(pred_ids)
print(f'num_gt_ids: {num_gt_ids} || num_pred_ids: {num_pred_ids}')
assert gt_ids == pred_ids, "please make sure pred and gt are exactly matched"

# perform evaluation
results = evaluate(gt, pred)


 data
num_gt_ids: 1061 || num_pred_ids: 1061


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


num_open 706
Metric               Performance
-----------------  -------------
exact match score       80.5309
f1 score                80.7272
precision               81.0188
recall                  80.971
bleu_score               3.58223
bleu_score_1            79.7757
bleu_score_2            20.3638
bleu_score_3             5.95555


In [27]:
# Number of per-question score records returned by evaluate() for the run above.
len(results)

706

In [28]:
# evaluate() returns its records sorted by ascending F1, so index 100 is the
# record at position 101 of that worst-first ranking.
results[100]

{'question_id': 12905,
 'exact_match': 0.0,
 'f1': 0,
 'precision': 0,
 'recall': 0,
 'bleu': 0,
 'bleu_1': 0,
 'bleu_2': 0,
 'bleu_3': 0}

In [29]:
# Index the prediction records by question id for direct lookup.
pred_dict = {record['question_id']: record for record in pred}

In [31]:
# Print the raw prediction records for positions 100-199 of the
# F1-ascending ranking.
for scored in results[100:200]:
    print(pred_dict[scored['question_id']])

{'question_id': 12905, 'prompt': '<image>\nHow many femoral heads are shown in this image?', 'text': 'Assistant: 2', 'gt': '0', 'answer_id': 'cbooaNQ8NpeGqipP4wR5HJ', 'model_id': '/data/aofei/LLM/llava_med', 'metadata': {}}
{'question_id': 12929, 'prompt': '<image>\nWhich is bigger in this image, small bowel or colon?', 'text': 'Assistant: Colon', 'gt': 'Small Bowel', 'answer_id': 'V9KXwkz6MRUh2abKXmpjXm', 'model_id': '/data/aofei/LLM/llava_med', 'metadata': {}}
{'question_id': 12935, 'prompt': '<image>\nWhat is the function of the organ on the top of this image?', 'text': 'Assistant: Store feces, excrete feces', 'gt': 'Absorb nutrients, secrete enzymes, digest food', 'answer_id': 'iqaALAan5pfvxMjkESfevo', 'model_id': '/data/aofei/LLM/llava_med', 'metadata': {}}
{'question_id': 12938, 'prompt': '<image>\nWhich part of the human body is the organ located in the image?', 'text': 'Assistant: Abdomen', 'gt': 'Pelvic Cavity', 'answer_id': 'Jku53oeNvjtK9VzyKoLWge', 'model_id': '/data/aofei/L

In [33]:
# Run configuration for the checkpoint evaluated in this cell.
visual_enhance_ratio = 0.08
bbox_ratio = 0.03
epoch_num = 6
ROOT_PATH = "/data/aofei"
dataset = "Slake"

gt_file = f"{ROOT_PATH}/hallucination/{dataset}/data/test.json"
# Built from the configuration above instead of a hard-coded absolute path;
# with the values set in this cell it resolves to
# /data/aofei/hallucination/mitigation/Slake/llava_med/lora/epoch6/inference/pred.jsonl
pred_file = f"{ROOT_PATH}/hallucination/mitigation/{dataset}/llava_med/lora/epoch{epoch_num}/inference/pred.jsonl"

# Report the configured dataset name. Re-deriving it with
# gt_file.split("/")[-2] yields the literal "data" directory component,
# so `dataset` is used directly and is no longer overwritten.
print(f"\n========\n {dataset}")

# Load ground truth with a context manager so the file handle is closed.
with open(gt_file, 'r', encoding='utf-8') as gt_reader:
    gt = json.load(gt_reader)
pred = load_jsonl(pred_file)

# evaluate() pairs items positionally, so the id sequences must be identical.
gt_ids = [item['id'] for item in gt]
pred_ids = [item['question_id'] for item in pred]
num_gt_ids, num_pred_ids = len(gt_ids), len(pred_ids)
print(f'num_gt_ids: {num_gt_ids} || num_pred_ids: {num_pred_ids}')
assert gt_ids == pred_ids, "please make sure pred and gt are exactly matched"

# perform evaluation; with return_pred=True the predictions come back too,
# carrying the 'gt' key that evaluate() adds to records lacking one.
results_lora, pred = evaluate(gt, pred, return_pred=True)

# Index the prediction records by question id for direct lookup.
pred_dict = {record['question_id']: record for record in pred}


 data
num_gt_ids: 1061 || num_pred_ids: 1061
num_open 706
Metric               Performance
-----------------  -------------
exact match score       81.9666
f1 score                82.2
precision               82.5391
recall                  82.5009
bleu_score               4.00716
bleu_score_1            81.1315
bleu_score_2            21.2204
bleu_score_3             6.22461


In [36]:
# Print the raw prediction records behind the first 100 entries of
# results_lora (the lowest-F1 questions, since the list is F1-ascending).
for scored in results_lora[:100]:
    print(pred_dict[scored['question_id']])

{'question_id': 11968, 'prompt': '<image>\nWhat diseases are included in the picture?', 'text': 'Assistant: Infiltration', 'answer_id': 'Xw2YyEKn2KtquKweHeyuoh', 'model_id': '/data/aofei/LLM/llava_med', 'metadata': {}, 'gt': 'Cardiomegaly'}
{'question_id': 11969, 'prompt': '<image>\nWhere is/are the abnormality located?', 'text': 'Assistant: Right Lung, Upper Left', 'answer_id': 'TZrRYtWLU48C8YdVnRwvPV', 'model_id': '/data/aofei/LLM/llava_med', 'metadata': {}, 'gt': 'Center'}
{'question_id': 11971, 'prompt': '<image>\nWhich organ is abnormal, heart or lung?', 'text': 'Assistant: Lung', 'answer_id': 'kh2GKdDXJbT29xcrXaVqRq', 'model_id': '/data/aofei/LLM/llava_med', 'metadata': {}, 'gt': 'Heart'}
{'question_id': 11984, 'prompt': '<image>\nWhat diseases are included in the picture?', 'text': 'Assistant: Atelectasis', 'answer_id': 'oPcg6bnQUBs398vLktmauT', 'model_id': '/data/aofei/LLM/llava_med', 'metadata': {}, 'gt': 'Pneumonia'}
{'question_id': 11991, 'prompt': '<image>\nWhat diseases ar