# CLIP Baseline

In [1]:
import dataset as ds
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as tt
import numpy as np

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [3]:

model_type = "base"

# Only the exact value "large" selects the large checkpoint; anything else
# falls through to the base checkpoint.
checkpoint = (
    "openai/clip-vit-large-patch14-336"
    if model_type == "large"
    else "openai/clip-vit-base-patch32"
)

# Model and processor are loaded from the same checkpoint id; the model is
# moved to the selected device and switched to eval mode.
model = AutoModel.from_pretrained(checkpoint)
model = model.to(device).eval()
processor = AutoProcessor.from_pretrained(checkpoint)




In [4]:
# Print a short summary of the runtime and of the loaded model's configuration.
config = model.config.to_dict()
vision_config = config['vision_config']
text_config = config['text_config']

# Number of elements in each parameter tensor; summed below for the total.
param_sizes = [int(np.prod(p.shape)) for p in model.parameters()]

print(f'Device: {device} ({torch.cuda.device_count()} gpus)')
print(f"Model parameters: {np.sum(param_sizes):,}")
print(f"Input resolution: {vision_config['image_size']}")
print(f"Context length: {text_config['max_position_embeddings']}")
print(f"Vocab size: {text_config['vocab_size']}")

Device: cuda (1 gpus)
Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408


In [5]:
# NOTE(review): `scale` is not part of the composed transform below and is not
# used in this cell; it is kept only in case another cell references it.
scale = tt.Resize((336, 336))
tensor = tt.PILToTensor()

# Use the public `tt.Compose` rather than reaching through the internal
# `tt.transforms` submodule. The only image transform is PIL -> tensor;
# presumably any resizing is left to the CLIP processor — TODO confirm.
image_composed = tt.Compose([tensor])

test_set = ds.VisualWSDDataset(mode="test", image_transform=image_composed)
# One dataset item per batch, iterated in dataset order (no shuffling).
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

In [6]:
# hit@k: the fraction of instances whose correct candidate is among the model's top-k predictions.
# With k == 1 this counts only the instances where the correct candidate is the top prediction.
def hit(results, k):
    """Return the hit@k rate: the share of results whose correct index is in the top k.

    Args:
        results: sequence of entries where ``entry[0]`` is the index of the
            correct candidate and ``entry[1]`` is a 1-D sequence of scores
            (higher score = better candidate).
        k: number of top-scoring candidates that count as a hit.

    Returns:
        ``hits / len(results)`` as a float, or ``0.0`` when ``results`` is
        empty (instead of raising ``ZeroDivisionError``).
    """
    if len(results) == 0:
        return 0.0

    hits = 0
    for entry in results:
        correct_idx, sims = entry[0], entry[1]
        # Indices of the k highest scores, best first.
        top_k = np.argsort(sims)[::-1][:k]
        if correct_idx in top_k:
            hits += 1

    return hits / len(results)

def mrr(results):
    """Return the mean reciprocal rank of the correct candidate over all results.

    Args:
        results: sequence of entries where ``entry[0]`` is the index of the
            correct candidate and ``entry[1]`` is a 1-D sequence of scores
            (higher score = better candidate).

    Returns:
        The average of ``1 / rank`` as a float, where ``rank`` is the 1-based
        position of the correct index when scores are sorted in descending
        order. Returns ``0.0`` when ``results`` is empty (instead of raising
        ``ZeroDivisionError``).

    Raises:
        IndexError: if an entry's correct index is not a valid position in
            its score sequence.
    """
    if len(results) == 0:
        return 0.0

    reciprocal_rank_total = 0.0
    for entry in results:
        correct_idx, sims = entry[0], entry[1]
        # Candidate indices ordered from highest to lowest score.
        ranking = np.argsort(sims)[::-1]
        # Position of the correct candidate in that ordering, made 1-based.
        rank = int(np.where(ranking == correct_idx)[0][0]) + 1
        reciprocal_rank_total += 1.0 / rank

    return reciprocal_rank_total / len(results)

In [7]:
# Sanity check: the cosine similarity of a random vector with itself should be 1.
cos_distance = torch.nn.CosineSimilarity(eps=1e-6, dim=1)
a = torch.randn((1, 10))
cos_distance(a, a)

tensor([1.])

In [8]:
def test(verbose=True):
    """Score every candidate image against the label text for each test batch.

    Uses the notebook-level ``test_loader``, ``processor``, ``model`` and
    ``device``. Gradients are disabled for the whole run.

    Args:
        verbose: when True (the default) print a log for every batch: the
            batch counter, the label text, the correct index and the
            similarity scores. Pass False to run silently.

    Returns:
        A list with one ``(correct_idx, sims)`` tuple per batch, where
        ``sims`` is a NumPy array holding, for each candidate image, the
        cosine similarity between the text and image features multiplied
        by 100.
    """
    # Loop-invariant: build the similarity module and read the loader length once.
    cos_distance = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    num_batches = len(test_loader)
    results = []

    with torch.no_grad():
        for batch_no, data in enumerate(test_loader, start=1):
            images = data["imgs"]
            # Only the first label of the batch is used.
            text = data["label_context"][0]
            correct_idx = data["correct_idx"].item()

            if verbose:
                print("----------------------------")
                print(f"batch: {batch_no}/{num_batches}")
                print(f"label: {text!s}")
                print(f"correct index: {correct_idx!s}")

            input_text = processor(text=text, return_tensors="pt", padding=True)
            text_output = model.get_text_features(**input_text.to(device))

            # Encode and score one candidate image at a time so that only a
            # single image's inputs/features are alive at once.
            sims = []
            for img in images:
                input_img = processor(images=img, return_tensors="pt", padding=True)
                img_output = model.get_image_features(**input_img.to(device))
                # .item() requires a single-element result, i.e. exactly one
                # text/image pair per comparison.
                sims.append(cos_distance(text_output, img_output).cpu().item() * 100)
            sims = np.array(sims)

            if verbose:
                print(f"sims: {sims!s}")

            results.append((correct_idx, sims))

    return results

In [9]:
# Run the evaluation over the whole test loader; `results` receives one
# (correct_idx, sims) tuple per batch.
results = test()

----------------------------
batch: 1/463
label: football goal
correct index: 8
sims: [25.65433383 20.5376938  23.53613377 22.56401032 19.67912167 24.54849631
 20.69906294 23.71804118 31.57707453 24.31022674]
----------------------------
batch: 2/463
label: mustard seed
correct index: 0
sims: [29.9520582  22.5094825  28.67417037 21.29943669 20.04111707 26.85717344
 22.86458313 27.03888118 21.51581049 19.29246932]
----------------------------
batch: 3/463
label: eating seat
correct index: 5
sims: [21.3126123  21.16441876 19.67388988 20.30988187 21.62075192 26.03778839
 23.05617332 20.11105418 23.52854908 23.07288647]
----------------------------
batch: 4/463
label: navigate the web
correct index: 6
sims: [17.71313697 18.8382864  19.44748163 17.87516177 17.42998213 20.91193199
 24.00656492 22.04234451 21.70941234 18.01635921]
----------------------------
batch: 5/463
label: butterball person
correct index: 2
sims: [19.83281076 17.89187342 27.03762054 15.51927179 16.83579683 21.13995105
 

In [10]:
# Report hit@1 and the mean reciprocal rank over all evaluated results.
print(f"hit1: {hit(results, 1)!s}")
print(f"mrr: {mrr(results)!s}")

hit1: 0.5874730021598272
mrr: 0.7299864582261988


In [11]:
# Spot-check a single entry (index 1): displays its (correct_idx, sims) tuple.
results[1]

(0,
 array([29.9520582 , 22.5094825 , 28.67417037, 21.29943669, 20.04111707,
        26.85717344, 22.86458313, 27.03888118, 21.51581049, 19.29246932]))