Please follow the instructions to set up the BLIP repo (hosted on [their GitHub](https://github.com/salesforce/BLIP)) before running this notebook.

In [None]:
from PIL import Image
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from models.blip_itm import blip_itm

image_size = 384  # side length (pixels) images are resized to; also passed to blip_itm

# Load BLIP model (checkpoint URL suggests the base-ViT COCO retrieval weights).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
model = blip_itm(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
# Put the model on the same device that load_demo_image() moves image tensors to.
# Hard-coding 'cpu' here causes a device-mismatch error whenever CUDA is available.
model = model.to(device)

def load_demo_image(image_path, device, show=True):
    """Load an image file and preprocess it into a BLIP input tensor.

    Args:
        image_path: Path of the image file to read; it is converted to RGB.
        device: torch device (or device string) the returned tensor is moved to.
        show: When True (the default, as before), display a 1/5-scale preview
            of the image in the notebook output. Pass False in batch loops to
            avoid flooding the output with one preview per image.

    Returns:
        Tensor of shape (1, 3, image_size, image_size): the image resized with
        bicubic interpolation, converted with ToTensor, and normalized with the
        per-channel mean/std constants below.
    """
    # Context manager closes the file handle deterministically; convert()
    # returns a new image, so raw_image stays valid after the file is closed.
    with Image.open(image_path) as img:
        raw_image = img.convert('RGB')
    if show:
        w, h = raw_image.size
        display(raw_image.resize((w // 5, h // 5)))
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    return transform(raw_image).unsqueeze(0).to(device)


def blip_score(image, caption, blip_model=None):
    """Score how well `caption` matches `image` with a BLIP ITM model.

    Args:
        image: Preprocessed image tensor (as returned by load_demo_image).
        caption: Caption text to score against the image.
        blip_model: Model to call; defaults to the notebook-level `model`.

    Returns:
        (itm_score, itc_score):
            itm_score: softmax of the 'itm' head output taken at index 1 of
                dim 1, i.e. the image-text matching probability per batch row.
            itc_score: raw output of the 'itc' head (cosine similarity between
                image and caption embeddings); shape as produced by the model.
    """
    scorer = model if blip_model is None else blip_model
    # Pure inference: skip building an autograd graph on every call.
    with torch.no_grad():
        itm_output = scorer(image, caption, match_head='itm')
        itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]  # Text matching prob
        itc_score = scorer(image, caption, match_head='itc')  # Cos sim between img and caption
    return itm_score, itc_score

In [24]:
# Image directories; figure ids are appended to these prefixes when loading images.
vistext_imgs_path = "../matplotalt/evaluation/vistext_eval/matplotlib_ver_imgs/"
gallery_imgs_path = "../matplotalt/evaluation/matplotlib_gallery/alt_figs/"

# Load vistext and matplotlib gallery captions (one JSON record per line)
vistext_captions_df = pd.read_json("../matplotalt/evaluation/vistext_eval/vistext_id_to_combined_captions.jsonl", orient='records', lines=True)
gallery_captions_df = pd.read_json("../matplotalt/evaluation/matplotlib_gallery/mpl_gallery_combined_captions_shuffled.jsonl", orient='records', lines=True)
print(vistext_captions_df.columns)
print(gallery_captions_df.columns)

# Caption columns to score. Any column absent from a row is recorded as NaN by
# the scoring cells, so column order here fixes the output column order.
vistext_col_names = ['human', 'heuristic', 'gpt-4-turbo-L4',
                     'gpt-4-turbo-alt-L4', 'gpt-4-turbo-table-L4',
                     'gpt-4-turbo-L3', 'gpt-4-turbo-alt-L3',
                     'gpt-4-turbo-table-L3', 'gpt-4-turbo-table-alt-L3']

gallery_col_names = ['heuristic', 'gpt-4-turbo-L3-225', 'gpt-4-turbo-alt-L3-225']

Index(['image_id', 'human', 'heuristic', 'gpt-4-turbo-L4',
       'gpt-4-turbo-alt-L4', 'gpt-4-turbo-table-L4', 'gpt-4-turbo-L3',
       'gpt-4-turbo-alt-L3', 'gpt-4-turbo-table-L3',
       'gpt-4-turbo-table-alt-L3'],
      dtype='object')
Index(['heuristic', 'gpt-4-turbo-L3-225', 'gpt-4-turbo-alt-L3-225',
       'figure_id'],
      dtype='object')


In [None]:
# Calculate vistext blipscores.
# For every figure, score each caption type in `vistext_col_names` with the BLIP
# ITM match probability and ITC cosine similarity. A caption type holding one
# caption stores a scalar, one holding several stores a list of scores, and a
# missing caption type stores NaN, so per-figure lists stay aligned with the
# column names. Results are written once, after the loop, to
# ./vistext_blipscores.jsonl.
vistext_id_to_probs = defaultdict(list)
vistext_id_to_cos_sims = defaultdict(list)
for _, row in tqdm(vistext_captions_df.iterrows(), total=len(vistext_captions_df)):
    fig_id = row["image_id"]
    image = load_demo_image(image_path=f"{vistext_imgs_path}{fig_id}.png", device=device)
    for desc_type in vistext_col_names:
        captions = row[desc_type] if desc_type in row else None
        # A missing value may surface as None or as a float NaN after read_json.
        if captions is None or (isinstance(captions, float) and np.isnan(captions)):
            vistext_id_to_probs[fig_id].append(np.nan)
            vistext_id_to_cos_sims[fig_id].append(np.nan)
            continue
        # Guard: a bare string would otherwise be iterated character by character.
        if isinstance(captions, str):
            captions = [captions]
        desc_type_probs = []
        desc_type_sims = []
        for cap in captions:
            cap = cap.replace("This description was generated by a language model.", "")
            cap_prob, cap_sim = blip_score(image, cap)
            desc_type_probs.append(cap_prob.item())
            desc_type_sims.append(cap_sim.item())
        if len(desc_type_probs) == 1:
            vistext_id_to_probs[fig_id].append(desc_type_probs[0])
            vistext_id_to_cos_sims[fig_id].append(desc_type_sims[0])
        else:
            vistext_id_to_probs[fig_id].append(desc_type_probs)
            vistext_id_to_cos_sims[fig_id].append(desc_type_sims)

# Build and save the result tables once, after all figures are scored
# (previously this ran inside the loop and rewrote the file on every row).
vistext_probs_df = pd.DataFrame.from_dict(vistext_id_to_probs, orient='index',
                                          columns=[cn + "-prob" for cn in vistext_col_names])
vistext_sims_df = pd.DataFrame.from_dict(vistext_id_to_cos_sims, orient='index',
                                         columns=[cn + "-cos-sim" for cn in vistext_col_names])
vistext_probs_df['figure_id'] = vistext_probs_df.index
vistext_sims_df['figure_id'] = vistext_sims_df.index
vistext_blipscore_df = pd.merge(vistext_probs_df, vistext_sims_df, on="figure_id", how="outer")
vistext_blipscore_df.to_json("./vistext_blipscores.jsonl", orient='records', lines=True)

In [None]:
# Calculate matplotlib gallery blipscores.
# For every gallery figure, score each caption type in `gallery_col_names` (one
# caption string per type) with the BLIP ITM match probability and ITC cosine
# similarity; a missing caption type stores NaN to keep columns aligned.
# Results are written once, after the loop, to ./gallery_blipscores.jsonl.
gallery_id_to_probs = defaultdict(list)
gallery_id_to_cos_sims = defaultdict(list)
for _, row in tqdm(gallery_captions_df.iterrows(), total=len(gallery_captions_df)):
    fig_id = row["figure_id"]
    image = load_demo_image(image_path=f"{gallery_imgs_path}nb_{fig_id}.jpg", device=device)
    for desc_type in gallery_col_names:
        caption = row[desc_type] if desc_type in row else None
        # A missing value may surface as None or as a float NaN after read_json.
        if caption is None or (isinstance(caption, float) and np.isnan(caption)):
            gallery_id_to_probs[fig_id].append(np.nan)
            gallery_id_to_cos_sims[fig_id].append(np.nan)
            continue
        caption = caption.replace("This description was generated by a language model.", "")
        cap_prob, cap_sim = blip_score(image, caption)
        gallery_id_to_probs[fig_id].append(cap_prob.item())
        gallery_id_to_cos_sims[fig_id].append(cap_sim.item())

# Build and save the result tables once, after all figures are scored
# (previously this ran inside the loop and rewrote the file on every row).
gallery_probs_df = pd.DataFrame.from_dict(gallery_id_to_probs, orient='index',
                                          columns=[cn + "-prob" for cn in gallery_col_names])
gallery_sims_df = pd.DataFrame.from_dict(gallery_id_to_cos_sims, orient='index',
                                         columns=[cn + "-cos-sim" for cn in gallery_col_names])
gallery_probs_df['figure_id'] = gallery_probs_df.index
gallery_sims_df['figure_id'] = gallery_sims_df.index
gallery_blipscore_df = pd.merge(gallery_probs_df, gallery_sims_df, on="figure_id", how="outer")
gallery_blipscore_df.to_json("./gallery_blipscores.jsonl", orient='records', lines=True)


In [47]:
def _nan_mean(value):
    """NaN-ignoring mean of a scalar or (rectangular) list of scores.

    Returns NaN when there is no non-NaN value, without the
    "Mean of empty slice" RuntimeWarning that np.nanmean emits in that case.
    """
    values = np.asarray(value, dtype=float).ravel()
    if np.isnan(values).all():  # also True for an empty array
        return np.nan
    return np.nanmean(values)


def print_mean_blipscores(blipscore_df, header):
    """Print `header`, then the mean of every score column in `blipscore_df`.

    Each cell (a scalar, or a list of per-caption scores) is first averaged
    per figure; the printed value is the NaN-ignoring mean over figures.
    The 'figure_id' column is skipped.
    """
    print(header)
    for col_name in blipscore_df.columns:
        if col_name == "figure_id":
            continue
        per_figure_means = [_nan_mean(v) for v in blipscore_df[col_name]]
        print(f"{col_name}: {_nan_mean(per_figure_means)}")


print_mean_blipscores(gallery_blipscore_df, "Matplotlib gallery average blipscores: ---------------------------------------------------")
print()
print_mean_blipscores(vistext_blipscore_df, "Vistext gallery average blipscores: ---------------------------------------------------")


Matplotlib gallery average blipscores: ---------------------------------------------------
heuristic-prob: 0.9269914858376802
gpt-4-turbo-L3-225-prob: 0.9765950807588017
gpt-4-turbo-alt-L3-225-prob: 0.9629381057338335
heuristic-cos-sim: 0.4262510082911496
gpt-4-turbo-L3-225-cos-sim: 0.48792191035118865
gpt-4-turbo-alt-L3-225-cos-sim: 0.4717426163640188

Vistext gallery average blipscores: ---------------------------------------------------
human-prob: 0.9999340540832944
heuristic-prob: 0.9999378859456164
gpt-4-turbo-L4-prob: 0.9979466208948184
gpt-4-turbo-alt-L4-prob: 0.9989826586011643
gpt-4-turbo-table-L4-prob: 0.999617625027895
gpt-4-turbo-L3-prob: 0.9984933224901591
gpt-4-turbo-alt-L3-prob: 0.9999297380988633
gpt-4-turbo-table-L3-prob: 0.9977022701246446
gpt-4-turbo-table-alt-L3-prob: 0.9999251992644735
human-cos-sim: 0.4992284501310739
heuristic-cos-sim: 0.5037227192906295
gpt-4-turbo-L4-cos-sim: 0.5023882733156307
gpt-4-turbo-alt-L4-cos-sim: 0.5049227834571891
gpt-4-turbo-table-L