In [1]:
import clip
import numpy as np
import pandas as pd
import PIL
import PIL.Image
import sklearn
import sklearn.metrics
import torch
from tqdm import tqdm

In [2]:
# Device used for the CLIP model and its input tensors:
# the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    DEVICE_SETTINGS = "cuda"
else:
    DEVICE_SETTINGS = "cpu"

# Calculate Metrics for CLIP Model

## Load Data

In [3]:
def load_image_markup(file_path):
    """Read an image markup file into a list of (image_path, description) pairs.

    Each non-blank line is expected to look like ``<relative path> "<description>"``.
    The text before the first double quote becomes the image path, prefixed
    with "../"; the text after it, without the closing quote, becomes the
    description.

    Args:
        file_path: path of the UTF-8 encoded markup file.

    Returns:
        A list of ``(path, description)`` tuples in file order.

    Raises:
        ValueError: if a non-blank line contains no double quote, i.e. it has
            no description part.
    """
    data = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            # Blank lines (e.g. an empty line at the end of the file) carry no
            # markup and must not turn into a bogus ("../", "") entry.
            if not line.strip():
                continue
            sep_ind = line.find("\"")
            if sep_ind < 0:
                raise ValueError(
                    f"{file_path}, line {line_no}: expected '<path> \"<description>\"', "
                    f"got {line.strip()!r}"
                )
            path = ("../" + line[:sep_ind]).strip()
            desc = line[sep_ind + 1:].strip()
            # Drop the closing quote only when it is really there, so that an
            # unterminated description does not lose its last character.
            if desc.endswith("\""):
                desc = desc[:-1]
            data.append((path, desc))
    return data


# Parse the markup file into a list of (image path, description) tuples.
img_markup = load_image_markup("../data/matching_images.txt")

## Metrics

In [5]:
def _group_texts_by_prefix(text_list):
    """Group descriptions that are equal or prefix-related into shared classes.

    Two descriptions belong to the same class when one of them starts with the
    other. Each description is compared only against the representatives that
    were already accepted: the prefix relation is not transitive, so comparing
    against skipped descriptions as well could leave a description without any
    matching representative.

    Args:
        text_list: list of description strings.

    Returns:
        ``(unique_indices, class_indices)`` where ``unique_indices`` holds the
        positions in ``text_list`` of the class representatives (first
        occurrence wins) and ``class_indices[i]`` is the position inside
        ``unique_indices`` of the class that ``text_list[i]`` belongs to.
    """
    unique_indices = []
    class_indices = []
    for idx, txt in enumerate(text_list):
        class_idx = -1
        for u_pos, u_idx in enumerate(unique_indices):
            u_txt = text_list[u_idx]
            if u_txt.startswith(txt) or txt.startswith(u_txt):
                class_idx = u_pos
                break
        if class_idx < 0:
            # No related representative yet: this description opens a new class.
            class_idx = len(unique_indices)
            unique_indices.append(idx)
        class_indices.append(class_idx)
    return unique_indices, class_indices


def calc_metrics(model_name, markup):
    """Score a CLIP model on image/description matching at its F1-optimal threshold.

    Every image is compared with every unique description; each
    (image, description) pair is one binary sample labelled 1 when the
    description is the one attached to the image (or prefix-related to it)
    and 0 otherwise. The logits of all pairs are min-max normalised, a ROC
    curve is built over them, and the threshold that maximises F1 is used to
    count the confusion-matrix cells.

    Args:
        model_name: model name passed to ``clip.load``.
        markup: sequence of ``(image_path, description)`` pairs.

    Returns:
        ``(tp, fp, tn, fn, opt_thr)``: the confusion-matrix counts at the
        selected threshold and that threshold mapped back to the raw logit
        scale.

    Raises:
        ValueError: if ``markup`` is empty or all logits are identical, because
            no threshold can be derived in either case.
    """
    text_list = [markup_line[1].lower() for markup_line in markup]
    if not text_list:
        raise ValueError("markup is empty: there is nothing to evaluate")

    model, preprocess_function = clip.load(model_name, device=DEVICE_SETTINGS)

    # some descriptions are the same or mostly same (start from the same words)
    unique_indices, class_indices = _group_texts_by_prefix(text_list)
    unique_text_list = [text_list[idx] for idx in unique_indices]
    unique_text_list_pt = clip.tokenize(unique_text_list).to(DEVICE_SETTINGS)
    n_classes = len(unique_text_list)

    y_true_list = []
    logit_list = []
    for img_idx, markup_line in tqdm(enumerate(markup), total=len(text_list)):
        with torch.no_grad():
            # The context manager closes the image file once it is preprocessed.
            with PIL.Image.open(markup_line[0]) as image:
                image_pt = preprocess_function(image).unsqueeze(0).to(DEVICE_SETTINGS)

            logits_per_image_pt, _ = model(image_pt, unique_text_list_pt)
            # Raw logits are kept on purpose: softmax can't be used for a
            # non-fixed list of variants, and a sigmoid is not accurate enough
            # here because all probabilities end up too close to 1.
            logits_per_image = logits_per_image_pt[0].cpu().numpy()

        u_img_idx = class_indices[img_idx]
        y_true_list += [1 if idx == u_img_idx else 0 for idx in range(n_classes)]
        logit_list += logits_per_image.tolist()

    y_true_vec = np.array(y_true_list)
    logit_vec = np.array(logit_list)
    assert y_true_vec.shape == logit_vec.shape

    # Min-max normalise the logits to [0, 1]; the chosen threshold is mapped
    # back to the raw logit scale before returning.
    l_min = logit_vec.min()
    l_max = logit_vec.max()
    if l_max == l_min:
        raise ValueError("all logits are identical, so no threshold can be selected")
    logit_vec = (logit_vec - l_min)/(l_max - l_min)

    # Precision is derived from recall, TNR and prevalence:
    # https://stats.stackexchange.com/q/287117/
    prevalence = np.count_nonzero(y_true_vec == 1)/len(y_true_vec)
    assert prevalence > 0
    fpr_vec, tpr_vec, thr_vec = sklearn.metrics.roc_curve(y_true_vec, logit_vec)
    recall_vec = tpr_vec
    tnr_vec = 1 - fpr_vec

    zero_div_idxs = np.where((recall_vec*prevalence) + ((1 - tnr_vec)*(1 - prevalence)) == 0)[0]
    if zero_div_idxs.size > 0:
        # to avoid zero-division warning from numpy below
        recall_vec[zero_div_idxs] += 1e-8
    precision_vec = (recall_vec*prevalence)/((recall_vec*prevalence) + ((1 - tnr_vec)*(1 - prevalence)))

    zero_div_idxs = np.where(precision_vec + recall_vec == 0)[0]
    if zero_div_idxs.size > 0:
        # to avoid zero-division warning from numpy below
        precision_vec[zero_div_idxs] += 1e-8
        recall_vec[zero_div_idxs] += 1e-8
    f1_vec = 2*(precision_vec*recall_vec)/(precision_vec + recall_vec)

    opt_thr = thr_vec[np.argmax(f1_vec)]
    if np.isinf(opt_thr):
        # An infinite threshold is replaced by the largest normalised score.
        opt_thr = 1

    # using ">=" below instead of ">" is extremely important, because sklearn can return edge values for threshold
    tp = np.count_nonzero((logit_vec >= opt_thr) & (y_true_vec == 1))
    fp = np.count_nonzero((logit_vec >= opt_thr) & (y_true_vec == 0))
    tn = np.count_nonzero((logit_vec < opt_thr) & (y_true_vec == 0))
    fn = np.count_nonzero((logit_vec < opt_thr) & (y_true_vec == 1))

    # Undo the min-max normalisation so the threshold applies to raw logits.
    opt_thr = float(opt_thr)*(l_max - l_min) + l_min

    return tp, fp, tn, fn, opt_thr

In [6]:
# Show the model names reported as available by the installed clip package.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [7]:
# Logit thresholds per CLIP model. They were picked up manually after a brief
# review of logits, so they are not accurate.
#
# Models skipped for full measurements because they eat too much resources,
# with the thresholds that were picked for them:
#   "RN50x4": 36, "RN50x16": 27, "RN50x64": 17,
#   "ViT-B/32": 29, "ViT-B/16": 28, "ViT-L/14@336px": 23
MODEL_THRS = {"RN50": 20, "RN101": 46, "ViT-L/14": 23}

In [8]:
# Column name -> list of per-model values; one entry is appended per evaluated model.
metric_dict = {
    column: []
    for column in ("Model", "TP", "FP", "TN", "FN", "Threshold", "Accuracy", "F1")
}
for model_name, model_thr in MODEL_THRS.items():
    # The threshold reported below is the one returned by calc_metrics;
    # the manually picked model_thr is not used here.
    tp, fp, tn, fn, thr = calc_metrics(model_name, img_markup)

    # F1 and accuracy fall back to 0 when their denominators are empty.
    f1_denominator = 2*tp + fp + fn
    f1 = 2*tp/f1_denominator if f1_denominator > 0 else 0

    sample_count = tp + fp + tn + fn
    acc = (tp + tn)/sample_count if sample_count > 0 else 0

    # The tuple follows the insertion order of metric_dict's columns.
    row = (model_name, tp, fp, tn, fn, thr, acc, f1)
    for column, value in zip(metric_dict, row):
        metric_dict[column].append(value)

metric_df = pd.DataFrame(metric_dict)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.83it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.57it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:11<00:00,  1.18s/it]


In [None]:
# Display the per-model metrics table built in the previous cell.
metric_df