In [1]:
import clip
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import PIL

In [2]:
DEVICE_SETTINGS = "cuda" if torch.cuda.is_available() else "cpu"

# Calculate Metrics for CLIP Model

## Load Data

In [3]:
def load_image_markup(file_path):
    data = []
    with open(file_path, "r", encoding="utf-8") as f:
        while True:
            line = f.readline()
            if len(line) == 0:
                break
            sep_ind = line.find("\"")
            path = ("../" + line[:sep_ind]).strip()
            desc = line[sep_ind + 1:].strip()[:-1]
            data.append((path, desc))
    return data


img_markup = load_image_markup("../data/matching_images.txt")

## Metrics

In [4]:
def calc_metrics(model_name, markup, threshold):
    model, preprocess_function = clip.load(model_name, device=DEVICE_SETTINGS)

    text_list = [markup_line[1].lower() for markup_line in markup]
    # some descriptions are the same or mostly same (start from the same words)
    unique_indices = []
    for idx in range(len(text_list)):
        if idx == 0:
            unique_indices.append(idx)
            continue
        prev_started = any(txt.startswith(text_list[idx]) for txt in text_list[:idx])
        current_started = any(text_list[idx].startswith(txt) for txt in text_list[:idx])
        if prev_started or current_started:
            continue
        unique_indices.append(idx)
    unique_text_list = [text_list[idx] for idx in unique_indices]
    unique_text_list_pt = clip.tokenize(unique_text_list).to(DEVICE_SETTINGS)

    tp = 0
    fp = 0
    tn = 0
    fn = 0
    for img_idx, markup_line in tqdm(enumerate(markup), total=len(text_list)):
        with torch.no_grad():
            image_pt = preprocess_function(PIL.Image.open(markup_line[0])).unsqueeze(0).to(DEVICE_SETTINGS)

            logits_per_image_pt, logits_per_text_pt = model(image_pt, unique_text_list_pt)
            # probs = logits_per_image_pt.softmax(dim=-1).cpu().numpy()  # can't be used for not fixed list of variants
            logits_per_image = logits_per_image_pt[0].cpu().numpy()
            # logits_per_image = 1/(1 + np.exp(-logits_per_image))  # does not work with enough accuracy: all probs are too close to 1
            # print(f"dbg: {logits_per_image = }")  # use this debug printout to setup threshold manualy

        if img_idx not in unique_indices:
            assert markup_line[1].lower() == text_list[img_idx]
            txt_to_find = markup_line[1].lower()
            u_img_idx = -1
            for i in range(len(unique_text_list)):
                if unique_text_list[i].startswith(txt_to_find) or txt_to_find.startswith(unique_text_list[i]):
                    u_img_idx = i
                    break
            assert u_img_idx >= 0
            assert u_img_idx < img_idx
        else:
            u_img_idx = unique_indices.index(img_idx)
        found_indices = np.argwhere(logits_per_image >= threshold)[:, 0].tolist()

        if u_img_idx in found_indices:
            tp += 1
            if len(found_indices) > 1:
                fp += len(found_indices) - 1
            tn += len(unique_indices) - len(found_indices)
        else:
            fn += 1
            fp += len(found_indices)
            tn += len(unique_indices) - len(found_indices) - 1

    assert tp + fp + tn + fn == len(unique_text_list)*len(text_list)

    return tp, fp, tn, fn

In [5]:
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [6]:
MODEL_THRS = {  # these thresholds were picked up manually after brief review of logits, so they are not accurate
    "RN50": 20,
    "RN101": 46,
    # "RN50x4": 36,   # eats too much resources, skipped for full measurements
    # "RN50x16": 27,  # eats too much resources, skipped for full measurements
    # "RN50x64": 17,  # eats too much resources, skipped for full measurements
    # "ViT-B/32": 29,  # eats too much resources, skipped for full measurements
    # "ViT-B/16": 28,  # eats too much resources, skipped for full measurements
    "ViT-L/14": 23,
    # "ViT-L/14@336px": 23,  # eats too much resources, skipped for full measurements
}

In [7]:
metric_dict = {"Model": [], "TP": [], "FP": [], "TN": [], "FN": [], "Accuracy": [], "F1": []}
for model_name, model_thr in MODEL_THRS.items():
    tp, fp, tn, fn = calc_metrics(model_name, img_markup, model_thr)

    if 2*tp + fp + fn > 0:
        f1 = 2*tp/(2*tp + fp + fn)
    else:
        f1 = 0
    
    if tp + fp + tn + fn > 0:
        acc = (tp + tn)/(tp + fp + tn + fn)
    else:
        acc = 0

    metric_dict["Model"].append(model_name)
    metric_dict["TP"].append(tp)
    metric_dict["FP"].append(fp)
    metric_dict["TN"].append(tn)
    metric_dict["FN"].append(fn)
    metric_dict["Accuracy"].append(acc)
    metric_dict["F1"].append(f1)

metric_df = pd.DataFrame(metric_dict)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 311/311 [43:09<00:00,  8.33s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 311/311 [42:38<00:00,  8.23s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 311/311 [1:23:31<00:00, 16.11s/it]


In [8]:
metric_df

Unnamed: 0,Model,TP,FP,TN,FN,Accuracy,F1
0,RN50,262,3618,72577,49,0.952069,0.12503
1,RN101,194,747,75448,117,0.988707,0.309904
2,ViT-L/14,259,1395,74800,52,0.981086,0.263613
