In [1]:
# Enable IPython's autoreload extension (mode 2) so that edits to imported
# project modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings

import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from tqdm import tqdm

from pkgs.openai.clip import load as load_model

# Load Winoground Dataset

In [3]:
# Read the Hugging Face access token from the HF_TOKEN environment variable instead
# of committing it to the notebook. Create a token from your Hugging Face account:
# Profile -> Settings -> Access Tokens -> New Token.
# When HF_TOKEN is unset or empty, `True` asks `datasets` to use the token cached
# locally by `huggingface-cli login`.
# NOTE(review): the token that used to be hard-coded here has been exposed and
# should be revoked.
auth_token = os.environ.get("HF_TOKEN") or True
winoground = load_dataset("facebook/winoground", use_auth_token=auth_token)["test"]

Found cached dataset winoground (C:/Users/dipti/.cache/huggingface/datasets/facebook___winoground/default/0.0.0/ce486f3e39fab90997d6f3c58c4b0103eb9c37011049ef775a465f0ab2e78d7d)


  0%|          | 0/1 [00:00<?, ?it/s]

# Load CLIP Model

In [4]:
## Build the RN50 CLIP model together with its processor.
## pretrained = True would load the original OpenAI CLIP model trained on 400M
## image-text pairs; pretrained = False is passed here (presumably because the
## weights are loaded from a checkpoint further down -- TODO confirm).
loaded_clip = load_model(name="RN50", pretrained=False)
clip_model, clip_processor = loaded_clip



In [5]:
# Device name used as `map_location` when loading the checkpoint and as the
# target of `.to(...)` for every model input.
device = "cpu"

In [6]:
## Path of the checkpoint file to evaluate -- replace it with your local location.
## Checkpoints are available at:
## https://drive.google.com/drive/u/0/folders/1K0kPJZ3MA4KAdx3Fpq25dgW59wIf7M-x
checkpoint = "../checkpoints/cyclip-500K.pt/best.pt"

In [7]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints obtained from a source you trust.
state_dict = torch.load(checkpoint, map_location=device)["state_dict"]

# Strip a leading "module." from the parameter names (presumably left by a
# DataParallel-style wrapper at training time -- TODO confirm). The prefix is
# checked key by key so that a key without it is kept intact instead of losing
# its first seven characters.
module_prefix = "module."
state_dict = {
    (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
    for key, value in state_dict.items()
}

# strict=False tolerates key mismatches; report them so that a partially
# initialised model is never evaluated silently.
load_result = clip_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint and model keys do not match: "
        f"missing={list(load_result.missing_keys)}, "
        f"unexpected={list(load_result.unexpected_keys)}"
    )

# Inputs are moved to `device` before the forward pass, so keep the model there too.
clip_model.to(device)
clip_model.eval()

CLIP(
  (visual): ModifiedResNet(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (relu): ReLU(inplace=True)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn

# Image and Text embeddings

In [8]:
def get_inputs(image, caption):
    """Prepare one image/caption pair for the CLIP model.

    The caption goes through ``clip_processor.process_text`` and the image,
    converted to RGB, through ``clip_processor.process_image``.

    Returns:
        A 3-tuple ``(input_ids, attention_mask, pixel_values)`` whose members
        have all been moved to ``device``; ``pixel_values`` additionally gets
        a new leading dimension via ``unsqueeze(0)``.
    """
    text_inputs = clip_processor.process_text(caption)
    rgb_image = image.convert("RGB")
    processed_image = clip_processor.process_image(rgb_image)

    input_ids = text_inputs['input_ids'].to(device)
    attention_mask = text_inputs['attention_mask'].to(device)
    pixel_values = processed_image.to(device).unsqueeze(0)
    return input_ids, attention_mask, pixel_values

In [23]:
# Embed both (image, caption) pairs of every Winoground example.
# Per-pair embeddings are collected in lists and concatenated once at the end;
# calling torch.cat inside the loop would re-copy the growing tensor on every
# iteration.
image_chunks = []
caption_chunks = []

with torch.no_grad():  # evaluation only, no autograd graph needed
    for example in tqdm(winoground):
        # Pair 0 first, then pair 1, so rows 2*i and 2*i + 1 belong to example i.
        for pair in (0, 1):
            input_ids, attention_mask, pixel_values = get_inputs(
                example[f"image_{pair}"], example[f"caption_{pair}"]
            )
            output = clip_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
            )
            image_chunks.append(output.image_embeds)
            caption_chunks.append(output.text_embeds)

# torch.cat rejects an empty list, so fall back to an empty tensor when the
# dataset yielded no examples.
images = torch.cat(image_chunks, 0) if image_chunks else torch.Tensor()
captions = torch.cat(caption_chunks, 0) if caption_chunks else torch.Tensor()

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [03:49<00:00,  1.74it/s]


# Calculate Similarity Score

In [26]:
def clipscore(model, images, captions):
    """Return the scaled similarity matrix between image and caption embeddings.

    The image embeddings are multiplied by ``exp(model.logit_scale)`` and then
    matrix-multiplied with the transposed caption embeddings, so entry
    ``[i, j]`` scores row ``i`` of ``images`` against row ``j`` of ``captions``.
    """
    temperature = model.logit_scale.exp()
    scaled_images = temperature * images
    return scaled_images @ captions.t()

In [27]:
# Similarity matrix: one row per image embedding, one column per caption embedding.
scores = clipscore(model=clip_model, images=images, captions=captions)

In [56]:
def recall_at_k(scores, k, dim):
    """Fraction of queries whose matching index is among their top-k scores.

    Query ``i`` is considered matched by candidate ``i`` (the diagonal of
    ``scores``).

    Args:
        scores: 2-D tensor of similarity scores.
        k: number of best-scoring candidates kept per query; clamped to the
            number of candidates available along ``dim``.
        dim: axis that is ranked. ``dim=1`` ranks the columns for every row,
            ``dim=0`` ranks the rows for every column.

    Returns:
        The recall as a Python float, or ``nan`` when ``scores`` is empty.
    """
    if scores.numel() == 0:
        return float("nan")

    # torch.topk raises when k exceeds the size of the ranked axis.
    k = min(k, scores.shape[dim])
    top_indices = torch.topk(scores, k, dim=dim).indices
    if dim == 0:
        # Put queries on the first axis: (num_queries, k).
        top_indices = top_indices.t()

    num_queries = top_indices.shape[0]
    targets = torch.arange(num_queries, device=top_indices.device).unsqueeze(1)
    hits = (top_indices == targets).any(dim=1)
    return hits.sum().item() / num_queries


# Cut-offs at which recall is reported.
RECALL_KS = (1, 2, 5)

## Image Recall
# For every image (row of `scores`), rank all captions (columns).
recall_image = {k: recall_at_k(scores, k, dim=1) for k in RECALL_KS}

## Text Recall
# For every caption (column of `scores`), rank all images (rows). The
# denominator is now the number of captions, matching the loop range.
recall_text = {k: recall_at_k(scores, k, dim=0) for k in RECALL_KS}

In [58]:
# Show recall@{1, 2, 5} obtained by ranking along dim=1 of `scores`
# (for every row, is the same-index column among the top-k?).
recall_image

{1: 0.055, 2: 0.105, 5: 0.17}

In [59]:
# Show recall@{1, 2, 5} obtained by ranking along dim=0 of `scores`
# (for every column, is the same-index row among the top-k?).
recall_text

{1: 0.06625, 2: 0.105, 5: 0.18}