In [1]:
import clip
import torch
import numpy as np
import transformers
from tqdm import tqdm
import PIL

In [2]:
# # Works only with an older version of transformers, so pin it:
# !pip install -U transformers==4.37.1

In [3]:
# # A bug in ipykernel prevents downloading BAAI/bge-large-en-v1.5, so make sure ipykernel is up to date:
# # https://github.com/huggingface/xet-core/issues/526
# # (the version specifier is quoted so the shell does not treat ">" as an output redirection)
# !pip install -U "ipykernel>=7.1.0"

# Calculate Metrics for UForm Model

In [4]:
METRICS_TXT_EMB_MODEL_NAME = "BAAI/bge-large-en-v1.5"

# Tokenizer/model pairs already loaded in this kernel, keyed by model name.
# Loading is by far the most expensive step, and the embedding function is
# called once per description and once per generated caption.
_METRICS_TXT_EMB_CACHE = {}


def _load_metrics_txt_emb_model(model_name):
    """Return the cached ``(tokenizer, model)`` pair for ``model_name``, loading it on first use."""
    if model_name not in _METRICS_TXT_EMB_CACHE:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        model = transformers.AutoModel.from_pretrained(model_name)
        model.eval()
        _METRICS_TXT_EMB_CACHE[model_name] = (tokenizer, model)
    return _METRICS_TXT_EMB_CACHE[model_name]


def get_metrics_txt_emb(text, model_name=METRICS_TXT_EMB_MODEL_NAME):
    """Embed a single text with the metrics embedding model.

    Args:
        text: The string to embed.
        model_name: Hugging Face model identifier of the embedding model.
            The tokenizer and model are loaded once per name and reused on
            subsequent calls.

    Returns:
        A 1-D NumPy array: the L2-normalized hidden state of the first token
        of the model's first output.
    """
    tokenizer, model = _load_metrics_txt_emb_model(model_name)

    encoded_input = tokenizer([text], padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        model_output = model(**encoded_input)
        # First output, first token position of every sequence in the batch.
        sentence_embeddings = model_output[0][:, 0]

    # The batch holds exactly one text, so return its (normalized) row.
    return torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)[0].numpy()

## Load Data

In [5]:
def load_image_markup(file_path, path_prefix="../"):
    """Read ``(image_path, description)`` pairs from a markup file.

    Every non-blank line is expected to hold an image path followed by a
    description in double quotes, e.g. ``dir/img.png "two cats"``. Everything
    before the first double quote is the path; everything after it is the
    description, with one trailing double quote removed when present.

    Args:
        file_path: Path of the UTF-8 encoded markup file.
        path_prefix: String prepended to every image path read from the file.

    Returns:
        A list of ``(path, description)`` tuples in file order.

    Raises:
        ValueError: If a non-blank line contains no double quote.
    """
    data = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no record; an empty description would also be a prefix of
            # every other description downstream.
            if not line.strip():
                continue
            sep_ind = line.find("\"")
            if sep_ind < 0:
                raise ValueError(
                    f"{file_path}:{line_no}: expected a double-quoted description, got {line!r}"
                )
            path = (path_prefix + line[:sep_ind]).strip()
            desc = line[sep_ind + 1:].strip()
            # Drop the closing quote only when it is really there.
            if desc.endswith("\""):
                desc = desc[:-1]
            data.append((path, desc))
    return data


# Parse the markup file once; the metrics cell below passes these
# (image path, description) pairs to calc_metrics.
img_markup = load_image_markup("../data/matching_images.txt")

## Metrics

In [6]:
def _assign_description_groups(text_list):
    """Group descriptions that are equal or share a prefix.

    A description joins the first already-kept description it has a prefix
    relation with (either one starts with the other); otherwise it is kept as
    a new unique description.

    Returns:
        ``(unique_indices, group_per_text)`` where ``unique_indices`` are the
        positions in ``text_list`` of the kept descriptions and
        ``group_per_text[i]`` is the position in ``unique_indices`` of the
        group that ``text_list[i]`` belongs to.
    """
    unique_indices = []
    group_per_text = []
    for idx, text in enumerate(text_list):
        group = None
        for slot, unique_idx in enumerate(unique_indices):
            kept = text_list[unique_idx]
            if kept.startswith(text) or text.startswith(kept):
                group = slot
                break
        if group is None:
            group = len(unique_indices)
            unique_indices.append(idx)
        group_per_text.append(group)
    return unique_indices, group_per_text


def calc_metrics(markup, model, processor_obj, threshold,
                 prompt=None, eos_token_id=151645, max_new_tokens=256):
    """Count TP/FP/TN/FN for caption-to-description matching over all images.

    For every image a caption is generated with ``model``, embedded with
    ``get_metrics_txt_emb`` and compared (Euclidean distance) with the
    embedding of every unique reference description. A description is
    "found" when its distance is ``<= threshold``. Each image therefore
    contributes one decision per unique description, so the four counters sum
    to ``number_of_unique_descriptions * number_of_images``.

    Args:
        markup: Sequence of ``(image_path, description)`` pairs.
        model: Object exposing ``generate(**inputs, ...)``.
        processor_obj: Callable building model inputs from ``text``/``images``;
            must also provide ``batch_decode`` and ``tokenizer.pad_token_id``.
        threshold: Maximum embedding distance that counts as a match.
        prompt: Instruction given to the model; ``None`` selects the default
            object-listing prompt.
        eos_token_id: Token id that stops generation.
            NOTE(review): 151645 was hard-coded in the original notebook;
            presumably the end-of-turn token of this model's tokenizer — confirm.
        max_new_tokens: Upper bound on the generated caption length in tokens.

    Returns:
        Tuple ``(tp, fp, tn, fn)`` of integer counts.
    """
    PROMPT = "List the objects in the photo in one sentence, no information about background needed"
    if prompt is None:
        prompt = PROMPT

    text_list = [markup_line[1].lower() for markup_line in markup]
    # Some descriptions are the same or mostly the same (one starts with the
    # other). Grouping is done against the kept descriptions only, so every
    # description is guaranteed to map to a group; comparing against dropped
    # duplicates as well could leave a description without any matching group.
    unique_indices, group_per_text = _assign_description_groups(text_list)
    unique_text_list = [text_list[idx] for idx in unique_indices]
    emb_unique_text_list = [get_metrics_txt_emb(txt) for txt in unique_text_list]
    n_unique = len(unique_text_list)

    tp = 0
    fp = 0
    tn = 0
    fn = 0
    for img_idx, markup_line in enumerate(markup):
        # Close the image file as soon as the processor has consumed it.
        with PIL.Image.open(markup_line[0]) as image:
            inputs = processor_obj(text=[prompt], images=[image], return_tensors="pt")
        with torch.inference_mode():
            output = model.generate(
                **inputs,
                do_sample=False,
                use_cache=True,
                max_new_tokens=max_new_tokens,
                eos_token_id=eos_token_id,
                pad_token_id=processor_obj.tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        decoded_out = processor_obj.batch_decode(output[:, prompt_len:], skip_special_tokens=True)[0]

        emb_decoded_out = get_metrics_txt_emb(decoded_out)
        distance_per_image = np.array(
            [np.linalg.norm(emb_unique_text - emb_decoded_out) for emb_unique_text in emb_unique_text_list]
        )
        # print(f"dbg: {distance_per_image = }")  # use this debug printout to set the threshold manually

        expected_group = group_per_text[img_idx]
        found_indices = set(np.flatnonzero(distance_per_image <= threshold).tolist())
        n_found = len(found_indices)

        if expected_group in found_indices:
            # The right description was found; every other found one is a false positive.
            tp += 1
            fp += n_found - 1
            tn += n_unique - n_found
        else:
            # The right description was missed; everything found is a false positive.
            fn += 1
            fp += n_found
            tn += n_unique - n_found - 1

    # Sanity check: one decision per (image, unique description) pair.
    assert tp + fp + tn + fn == len(unique_text_list)*len(text_list)

    return tp, fp, tn, fn

In [7]:
# Checkpoint under evaluation; the same name is printed in the report cell.
MODEL_NAME = "unum-cloud/uform-gen2-qwen-500m"
# trust_remote_code=True is passed for both objects.
# NOTE(review): this presumably lets transformers execute model/processor code shipped
# with the checkpoint repository — only use it with a repository you trust.
model = transformers.AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
processor_func = transformers.AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
# Threshold handed to calc_metrics, where it is compared against embedding
# distances (`distance <= threshold`): smaller values mean stricter matching.
SIM_THR = 0.88
tp, fp, tn, fn = calc_metrics(img_markup, model, processor_func, SIM_THR)

# F1 = 2*TP / (2*TP + FP + FN); reported as 0 when there is nothing to divide by.
f1_denominator = 2*tp + fp + fn
f1 = 2*tp/f1_denominator if f1_denominator > 0 else 0

# Accuracy = (TP + TN) / all decisions; reported as 0 when no decision was made.
total_decisions = tp + fp + tn + fn
acc = (tp + tn)/total_decisions if total_decisions > 0 else 0

In [9]:
# Report the evaluated model, the raw confusion counts and the derived scores.
report_lines = (
    f"Model: {MODEL_NAME}",
    f"TP: {tp}, FP: {fp}, TN: {tn}, FN: {fn}",
    f"Accuracy: {acc}",
    f"F1: {f1}",
)
print("\n".join(report_lines))

Model: unum-cloud/uform-gen2-qwen-500m
TP: 146, FP: 785, TN: 75410, FN: 165
Accuracy: 0.9875826732543853
F1: 0.23510466988727857
