# sigLIP Baseline

In [19]:
import dataset as ds
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as tt
import numpy as np

In [20]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [21]:

model_type = "base"

if model_type == "large":
    model = AutoModel.from_pretrained("google/siglip-large-patch16-384").to(device).eval()
    processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-384")
else:
    model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(device).eval()
    processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")


In [22]:
config = model.config.to_dict()

print(f'Device: {device} ({torch.cuda.device_count()} gpus)')
print(f"Model parameters: {np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print(f"Input resolution: {config['vision_config']['image_size']}")
print(f"Context length: {config['text_config']['max_position_embeddings']}")
print(f"Vocab size: {config['text_config']['vocab_size']}")

Device: cuda (1 gpus)
Model parameters: 203,155,970
Input resolution: 224
Context length: 64
Vocab size: 32000


In [23]:
scale = tt.Resize((336, 336))
tensor = tt.PILToTensor()
image_composed = tt.transforms.Compose([tensor])

test_set = ds.VisualWSDDataset(mode="test", image_transform=image_composed)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

In [24]:
# if k is 1 gives all instances with the correct prediction as top prediction
# if k > 1 the correct prediction is in the top k predictions of the model
def hit(results, k):
    counter = 0

    for r in results:
        sims = r[1]
        sorted = np.argsort(sims)[::-1][:k]

        if r[0] in sorted:
            counter += 1

    return counter / len(results)

def mrr(results):
    sum = 0

    for r in results:
        sims = r[1]
        sorted = np.argsort(sims)[::-1]
        sum += 1/(np.where(sorted==r[0])[0][0]+1)

    return sum / len(results)

In [25]:
cos_distance = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
a = torch.randn(1, 10)
cos_distance(a, a)

tensor([1.])

In [54]:
import pickle
#context_glossar = {}

with open('context_glossar_0_463.pkl', 'rb') as f:
    context_glosar = pickle.load(f)

with open('context_snys_0_463.pkl', 'rb') as f:
    context_syns = pickle.load(f)

In [27]:
context_glosar

{'football': "Any of various games played with a ball (round or oval) in which two teams try to kick or carry or propel the ball into each other's goal",
 'seed': 'A small hard fruit',
 'eating': 'Take in solid food',
 'the web': 'An intricate network suggesting something that was formed by weaving or interweaving',
 'person': 'A human being',
 'statue': 'A sculpture representing a human or animal',
 'ear': 'The sense organ for hearing and equilibrium',
 'wild': 'Marked by extreme lack of restraint or control',
 'embrace': "Include in scope; include as part of something broader; have as one's sphere or territory",
 'grinding': 'Press or grind with a crushing noise',
 'winged': 'Having wings or as if having wings of a specified kind;',
 'rowing': 'Propel with oars',
 'insect': 'Small air-breathing arthropod',
 'trotting': 'Ride at a trot',
 'water': 'Binary compound that occurs at room temperature as a clear colorless odorless tasteless liquid; freezes into ice below 0 degrees centigrad

In [60]:
def test():
    
    with torch.no_grad():

        results = []

        for a,data in enumerate(test_loader):
            images = data["imgs"]
            text = data["label_context"][0]
            correct_idx = data["correct_idx"].item()


            # cheating
            #context_only = str(data['label_context'][0]).replace(str(data['label'][0]), '').strip()
            #print(context_only)
            #if context_only in context_syns:
            #    syns = context_syns[context_only]
            #    extra = ""
            #    if len(syns) > 0:
            #        extra = syns[0].replace('_', ' ')
                #text = "An image of a '" + str(data['label'][0]) + "'. In '" + context_only + "'. "# + glos
            #    text = "A iamage of '" + data['label'][0] + ", " + context_only + ", " + extra + "'."
            #    if len(text) > 210:
            #        text = text[:210]

            # cheating end


            print("----------------------------")
            print("batch: " + str(a+1) + "/" + str(len(test_loader)))
            print("label: " + str(text))
            print("correct index: " + str(correct_idx))

            input_text = processor(text=text, return_tensors="pt", padding="max_length")

            input_imgs = []
            for img in images:
                input_imgs.append(processor(images=img, return_tensors="pt", padding="max_length"))

            img_output = []
            for img in input_imgs:
                img_output.append(model.get_image_features(**img.to(device)))

            text_output = model.get_text_features(**input_text.to(device))

            cos_distance = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

            #print("text output shape: " + str(text_output.shape))
            #print("img output shape: " + str(img_output[0].shape))
        
            
            sims = []
            for i in range(len(img_output)):
                sims.append(cos_distance(text_output, img_output[i]).cpu().item()*100)

            sims = np.array(sims)
            #min_idx = np.argmin(sims)
            #max_idx = np.argmax(sims)

            print("sims: " + str(sims))
            #print("min val: " + str(sims[min_idx]) + " min index: " + str(min_idx))
            #print("max val: " + str(sims[max_idx]) + " max index: " + str(max_idx))

            results.append((correct_idx, sims))

            #if a > 1:
            #    break

    return results

In [61]:
results = test()

football
----------------------------
batch: 1/463
label: A iamage of 'goal, football, football game'.
correct index: 8
sims: [ 2.03710534 -7.89147615  1.41916964 -3.16982493 -5.75918555  4.23861593
  0.68208831  2.45082639 10.10461673 -0.82674008]
seed
----------------------------
batch: 2/463
label: A iamage of 'mustard, seed, seeded player'.
correct index: 0
sims: [12.18896657 -0.92041641  7.21838623  5.60388602 -1.13415569 -0.74560186
 -0.19483957  0.49675549  3.5162881  -1.97200123]
eating
----------------------------
batch: 3/463
label: A iamage of 'seat, eating, feeding'.
correct index: 5
sims: [ 4.03467193 -0.05563977 -1.13101453  2.94558145  2.31133234  4.98416461
 -3.2236401  -1.97077319  1.94346868 -0.91226771]
the web
----------------------------
batch: 4/463
label: A iamage of 'navigate, the web, '.
correct index: 6
sims: [-9.86098647 -1.34696597 -0.85547343 -8.24827552 -7.42986277 -2.70425379
  3.49928439  1.0051636  -3.38911824 -3.97237986]
person
-----------------------

In [62]:
print("hit1: " + str(hit(results, 1)))
print("mrr: " + str(mrr(results)))

hit1: 0.6544276457883369
mrr: 0.7779363365216494


In [None]:
results[1]

(0,
 array([12.63373047, -2.04947777,  9.7230278 ,  1.60970464, -0.37627704,
        -0.37521236,  0.58928672,  5.23877144,  2.21057273, -1.97367258]))