In [5]:
import numpy as np
import pandas as pd
from pprint import pprint
from tqdm import tqdm
from collections import defaultdict
from rouge_score import rouge_scorer
from sacrebleu.metrics import BLEU, CHRF
from evaluate import load

In [6]:
rscorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL', 'rougeLsum'], use_stemmer=True)
bleurt = load("bleurt", module_type="metric", checkpoint="bleurt-large-512")
bertscore = load("bertscore")
bleu = BLEU()
chrf = CHRF()

def compute_scores(refs, sample):
    ref_scores = defaultdict(list)
    for ref in refs:
        rouge_scores = rscorer.score(ref, sample)
        for rouge_metric, score_result in rouge_scores.items():
            ref_scores[f"{rouge_metric}_p"].append(score_result.precision)
            ref_scores[f"{rouge_metric}_r"].append(score_result.recall)
            ref_scores[f"{rouge_metric}_f1"].append(score_result.fmeasure)
        ref_scores["bleu"].append(bleu.corpus_score([sample], [[ref]]).score)
        ref_scores["chrf"].append(chrf.corpus_score([sample], [[ref]]).score)
        ref_scores["bleurt"].append(bleurt.compute(predictions=[sample], references=[ref])["scores"][0])
        bertscores = bertscore.compute(predictions=[sample], references=[ref], lang="en")
        ref_scores["bertscore_p"] = bertscores["precision"]
        ref_scores["bertscore_r"] = bertscores["recall"]
        ref_scores["bertscore_f1"] = bertscores["f1"]
    return ref_scores









INFO:tensorflow:Reading checkpoint C:\Users\Kai\.cache\huggingface\metrics\bleurt\default\downloads\extracted\66d40c89ded88d187db3310c752ad6bc55a18f1686c772fd971b1af93164b5f5\bleurt-base-128.


INFO:tensorflow:Reading checkpoint C:\Users\Kai\.cache\huggingface\metrics\bleurt\default\downloads\extracted\66d40c89ded88d187db3310c752ad6bc55a18f1686c772fd971b1af93164b5f5\bleurt-base-128.


INFO:tensorflow:Config file found, reading.


INFO:tensorflow:Config file found, reading.


INFO:tensorflow:Will load checkpoint bert_custom


INFO:tensorflow:Will load checkpoint bert_custom


INFO:tensorflow:Loads full paths and checks that files exists.


INFO:tensorflow:Loads full paths and checks that files exists.


INFO:tensorflow:... name:bert_custom


INFO:tensorflow:... name:bert_custom


INFO:tensorflow:... vocab_file:vocab.txt


INFO:tensorflow:... vocab_file:vocab.txt


INFO:tensorflow:... bert_config_file:bert_config.json


INFO:tensorflow:... bert_config_file:bert_config.json


INFO:tensorflow:... do_lower_case:True


INFO:tensorflow:... do_lower_case:True


INFO:tensorflow:... max_seq_length:128


INFO:tensorflow:... max_seq_length:128


INFO:tensorflow:Creating BLEURT scorer.


INFO:tensorflow:Creating BLEURT scorer.


INFO:tensorflow:Creating WordPiece tokenizer.


INFO:tensorflow:Creating WordPiece tokenizer.








INFO:tensorflow:WordPiece tokenizer instantiated.


INFO:tensorflow:WordPiece tokenizer instantiated.


INFO:tensorflow:Creating Eager Mode predictor.


INFO:tensorflow:Creating Eager Mode predictor.


INFO:tensorflow:Loading model.


INFO:tensorflow:Loading model.


INFO:tensorflow:BLEURT initialized.


INFO:tensorflow:BLEURT initialized.


In [11]:
vistext_id_to_captions = pd.read_json("./vistext_eval/vistext_id_to_combined_captions.jsonl", orient="records", lines=True)
vistext_id_to_scores = defaultdict(dict) #np.load("./vistext_id_to_scores.npy", allow_pickle=True).item()
pbar = tqdm(total=len(vistext_id_to_captions))
def get_caption_scores(row):
    refs = row["human"]
    for caption_type in ["heuristic", "gpt-4-turbo-L3", "gpt-4-turbo-alt-L3", "gpt-4-turbo-table-L3", "gpt-4-turbo-table-alt-L3"]:
        if caption_type in row and row[caption_type]:
            processed_caption = row[caption_type][0]
            processed_caption = processed_caption.replace("This description was generated by a language model. ", "")
            #print(caption_type, processed_caption)
            #print("----------------------------------------------------------")
            vistext_id_to_scores[row["image_id"]][caption_type] = compute_scores(refs, processed_caption)
    pbar.update(1)

vistext_id_to_captions.apply(get_caption_scores, axis=1)
np.save("./vistext_id_to_sim_scores", vistext_id_to_scores)

  1%|          | 10/882 [01:05<1:35:27,  6.57s/it]
100%|██████████| 882/882 [1:11:27<00:00,  6.78s/it]

In [16]:
vistext_id_to_scores = np.load("./vistext_id_to_sim_scores.npy", allow_pickle=True).item()
combined_method_scores = {}
for caption_type in ["heuristic", "gpt-4-turbo-L3", "gpt-4-turbo-alt-L3", "gpt-4-turbo-table-L3", "gpt-4-turbo-table-alt-L3"]:
    combined_method_scores[caption_type] = defaultdict(list)
for image_id, caption_scores in vistext_id_to_scores.items():
    for caption_type, scores_dict in caption_scores.items():
        for score_type, scores in scores_dict.items():
            combined_method_scores[caption_type][score_type].extend(scores)
caption_type_to_avg_scores = defaultdict(dict)
for caption_type, caption_scores in combined_method_scores.items():
    for score_type, scores in caption_scores.items():
        caption_type_to_avg_scores[caption_type][score_type] = np.mean(scores)

pprint(caption_type_to_avg_scores)

defaultdict(<class 'dict'>,
            {'gpt-4-turbo-L3': {'bertscore_f1': 0.882644063964182,
                                'bertscore_p': 0.8704650796730232,
                                'bertscore_r': 0.8953872045962449,
                                'bleu': 14.14066381414324,
                                'bleurt': -0.14390061662686388,
                                'chrf': 46.430975596845954,
                                'rouge1_f1': 0.4636736219061632,
                                'rouge1_p': 0.3759233302022657,
                                'rouge1_r': 0.6377347115494284,
                                'rouge2_f1': 0.2260586446410901,
                                'rouge2_p': 0.1821134742350561,
                                'rouge2_r': 0.31470937233073093,
                                'rougeL_f1': 0.32315676028608786,
                                'rougeL_p': 0.2609146964897651,
                                'rougeL_r': 0.44769307784171214,
      