In [5]:
import open_clip
import torch


clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
device = torch.device("cpu")
clip_model = clip_model.to(device)
clip_model.eval()
clip_tokenizer = open_clip.get_tokenizer("ViT-B-32")



  from .autonotebook import tqdm as notebook_tqdm


In [26]:
def get_sims(a_labels,b_labels):

    clip_encoded_a_labels = clip_model.encode_text(clip_tokenizer(a_labels))
    clip_encoded_b_labels = clip_model.encode_text(clip_tokenizer(b_labels))
    with torch.no_grad(), torch.autocast("cuda"):
        text_sims = (clip_encoded_a_labels @ clip_encoded_b_labels.T)
        
    for alabel,sims in zip(a_labels,text_sims):
        print(f"similarity {alabel}")   
        for blabel,sim in zip(b_labels,sims):
            print(f"    -{blabel}: {sim}")
    return text_sims

In [27]:
a_labels = ["left eye","right eye","mouth"]
b_labels = ["left pupil","right pupil"]

get_sims(a_labels,b_labels)

similarity left eye
    -left pupil: 58.51355743408203
    -right pupil: 51.71540832519531
similarity right eye
    -left pupil: 54.226497650146484
    -right pupil: 61.25897216796875
similarity mouth
    -left pupil: 38.34929656982422
    -right pupil: 41.58185577392578


tensor([[58.5136, 51.7154],
        [54.2265, 61.2590],
        [38.3493, 41.5819]])

In [17]:
import numpy as np


def get_cor_labels(a_labels, b_labels):

    if len(b_labels) > len(a_labels):
        tmp = b_labels
        b_labels = a_labels
        a_labels = tmp

    clip_encoded_a_labels = clip_model.encode_text(clip_tokenizer(a_labels))
    clip_encoded_b_labels = clip_model.encode_text(clip_tokenizer(b_labels))
    with torch.no_grad(), torch.autocast("cuda"):
        text_sims = clip_encoded_a_labels @ clip_encoded_b_labels.T

    text_sims = sorted(zip(text_sims, a_labels), key=lambda d: max(d[0]), reverse=True)
    label_correspondance = {}
    used_simlabel_indexes = []
    b_labels = np.array(b_labels)

    for sims, alabel in text_sims:
        m_sims = sims
        if len(used_simlabel_indexes) != 0:
            mask = torch.ones(sims.size(0), dtype=torch.bool)
            mask[used_simlabel_indexes] = False
            m_sims = sims[mask]
            if len(m_sims) == 0:
                break
            b_labels = np.delete(b_labels, used_simlabel_indexes)

        index_max = torch.argmax(m_sims)
        label_correspondance[alabel] = b_labels[index_max]
        used_simlabel_indexes.append(index_max.item())
    return label_correspondance

In [19]:
a_labels = ["left eye","right eye","mouth"]
b_labels = ["left pupil","right pupil"]

get_cor_labels(a_labels,b_labels)

{'right eye': np.str_('right pupil'), 'left eye': np.str_('left pupil')}