In [1]:
# !conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=10.2 -c pytorch
# !pip install ftfy==5.8
# !conda install transformers
# !pip install git+https://github.com/openai/CLIP.git
# import matplotlib.pyplot as plt
# !pip install ipywidgets
# !git clone https://github.com/FreddeFrallan/Multilingual-CLIP
# !cd Multilingual-CLIP
# pip install rich 

import numpy as np

import warnings
warnings.filterwarnings("ignore")

import clip


# Validation pipeline 

In [1]:
import torch 

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

device

device(type='cuda')

In [2]:
from rich import print 

In [4]:
# Read the images from of the dataset 
import os 

img_folder = 'photos/'

if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
    os.makedirs(img_folder, exist_ok=True)

In [5]:
import json 

data = [] 

with open("en_ar_XTD10_edited_v2.jsonl") as filino:

    for file_i in filino:

        dic_obj = json.loads(file_i)
        data.append(dic_obj)

In [6]:
print("Dataset size is: ", len(data) )

In [7]:
print(data[:10])

In [12]:
Check_id_duplication = [] 

In [13]:
for idx, data_obj in enumerate(data):

    Check_id_duplication.append(data_obj["id"])

In [14]:
# If the len is 1000, there is no duplicates

len(set(Check_id_duplication)) == 1000

True

In [15]:
# data = [
#     {'image_id': 0, 'id': 391895, 'caption': 'رجل يرتدي خوذة حمراء على دراجة بخارية صغيرة على طريق ترابي'},
#     {'image_id': 1, 'id': 522418, 'caption': 'امرأة ترتدي شبكة على رأسها تقطع كعكة'},
#     {'image_id': 2, 'id': 184613, 'caption': 'طفل يحمل مظلة مزهرة ويأكل ثورًا'},
# ]

# Sort the list of dictionaries based on the 'id' key
sorted_data = sorted(data, key=lambda x: x['id'])

print(sorted_data[:20])
# # Print the sorted list
# for item in sorted_data:
#     print(item)

In [16]:
# get only 10 examples
# sorted_data

In [17]:
len(sorted_data)

1000

In [18]:
print(sorted_data[:10])

In [19]:
image_name_list = []

for lin in sorted_data:
    # print(lin["image_name"])
    image_name_list.append(lin["image_name"])

In [20]:
print(image_name_list)

In [21]:
sorted_data[0]

{'caption_en': 'major league baseball game with player from pittsburgh pirates crossing home plate',
 'caption_ar': 'تخطي لاعب فريق بيتسبرج بايرتس منطقة اللوحة الرئيسية في مباراة بدوري البيسبول',
 'image_name': 'COCO_train2014_000000061844.jpg',
 'id': 61844}

In [22]:
# Create a mapping dictionary between the ids and paths

id2path = {}


for im_path, sort_sample in zip(image_name_list, sorted_data):


    # print(json.loads(lin)["text"])
    # print(im_path.split("_")[-1].split(".")[0])

    input_str = im_path.split("_")[-1].split(".")[0]
    # print(input_str)
    result = int(input_str.lstrip('0'))
    # Check the ids
    if sort_sample['id'] != result:
        print("stop ........................................................")
    id2path[result] = im_path

    # print(result)

In [23]:
id2path

{61844: 'COCO_train2014_000000061844.jpg',
 61849: 'COCO_train2014_000000061849.jpg',
 61850: 'COCO_train2014_000000061850.jpg',
 61852: 'COCO_train2014_000000061852.jpg',
 61854: 'COCO_train2014_000000061854.jpg',
 61865: 'COCO_train2014_000000061865.jpg',
 61867: 'COCO_train2014_000000061867.jpg',
 61877: 'COCO_train2014_000000061877.jpg',
 61881: 'COCO_train2014_000000061881.jpg',
 61892: 'COCO_train2014_000000061892.jpg',
 61895: 'COCO_train2014_000000061895.jpg',
 61904: 'COCO_train2014_000000061904.jpg',
 61911: 'COCO_train2014_000000061911.jpg',
 61918: 'COCO_train2014_000000061918.jpg',
 61919: 'COCO_train2014_000000061919.jpg',
 61936: 'COCO_train2014_000000061936.jpg',
 61945: 'COCO_train2014_000000061945.jpg',
 61946: 'COCO_train2014_000000061946.jpg',
 61949: 'COCO_train2014_000000061949.jpg',
 61951: 'COCO_train2014_000000061951.jpg',
 61966: 'COCO_train2014_000000061966.jpg',
 61982: 'COCO_train2014_000000061982.jpg',
 61992: 'COCO_train2014_000000061992.jpg',
 62017: 'CO

In [24]:
# Check if each image file exists in the folder

folder_path = "photos/XTD10_dataset"

missing_images = []

for image_path in image_name_list:
    full_image_path = os.path.join(folder_path, image_path)
    if not os.path.exists(full_image_path):
        missing_images.append(image_path)

if missing_images:
    print("The following images are missing:")
    for image_path in missing_images:
        print(image_path)
else:
    print("All images are present in the folder.")

In [25]:
# Delete the images that are not included on the testing dataset 

import os


not_exist_paths = []
exist_paths = [] 

# Get a list of all files in the folder
all_files = os.listdir(folder_path)

# Remove any files in the folder that are not in the list of image paths
for file_name in all_files:
    if file_name not in image_name_list:
        file_path = os.path.join(folder_path, file_name)
        os.remove(file_path)
        # print(f"Removed: {file_path}")
        not_exist_paths.append(file_path)

    elif file_name in image_name_list:

        exist_paths.append(file_name)


destroy_images = set(not_exist_paths).difference(set(exist_paths))


print("img_names", len(all_files))
print("destroy_images", len(destroy_images))
print("not_exist_paths", len(not_exist_paths))
print("remaining images", len(all_files)- len(destroy_images))

# print("Finished removing unwanted images.")

Define the the text model 

In [26]:
import pickle

import torch
import transformers

class MultilingualClipEdited(torch.nn.Module):
    def __init__(self, model_name, tokenizer_name, head_name, weights_dir='data/weights/', cache_dir=None,in_features=None,out_features=None):
        super().__init__()
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        self.head_path = weights_dir + head_name

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir)
        print(self.tokenizer )
        self.transformer = transformers.AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
        self.clip_head = torch.nn.Linear(in_features=in_features, out_features=out_features)
        self._load_head()

    def forward(self, txt):
        txt_tok = self.tokenizer(txt, padding=True, return_tensors='pt')
        embs = self.transformer(**txt_tok)[0]
        
        print("embs shape: ", embs.shape)

        att = txt_tok['attention_mask']

        print("att shape: ", att.shape)
    
        embs = (embs * att.unsqueeze(2)).sum(dim=1) / att.sum(dim=1)[:, None]

        print("embs after att shape: ", embs.shape)

        return self.clip_head(embs)

    def _load_head(self):
        with open(self.head_path, 'rb') as f:
            lin_weights = pickle.loads(f.read())
        self.clip_head.weight = torch.nn.Parameter(torch.tensor(lin_weights[0]).float().t())
        self.clip_head.bias = torch.nn.Parameter(torch.tensor(lin_weights[1]).float())

AVAILABLE_MODELS = {
    'M-BERT-Distil-40': {
        'model_name': 'M-CLIP/M-BERT-Distil-40',
        'tokenizer_name': 'M-CLIP/M-BERT-Distil-40',
        'head_name': 'M-BERT Distil 40 Linear Weights.pkl'
    },

    'M-BERT-Base-69': {
        'model_name': 'M-CLIP/M-BERT-Base-69',
        'tokenizer_name': 'M-CLIP/M-BERT-Base-69',
        'head_name': 'M-BERT-Base-69 Linear Weights.pkl'
    },

    'Swe-CLIP-500k': {
        'model_name': 'M-CLIP/Swedish-500k',
        'tokenizer_name': 'M-CLIP/Swedish-500k',
        'head_name': 'Swedish-500k Linear Weights.pkl'
    },

    'Swe-CLIP-2M': {
        'model_name': 'M-CLIP/Swedish-2M',
        'tokenizer_name': 'M-CLIP/Swedish-2M',
        'head_name': 'Swedish-2M Linear Weights.pkl'
    },
    
    'M-BERT-Base-ViT-B': {
        'model_name': 'M-CLIP/M-BERT-Base-ViT-B',
        'tokenizer_name': 'M-CLIP/M-BERT-Base-ViT-B',
        'head_name': 'M-BERT-Base-69-ViT Linear Weights.pkl'
    },
    'M-BERT-Base-ViT-B-ours': {
        'model_name': 'Arabic-Clip/m-bert-base-ViT-B-32-trained-mclip-data',
        'tokenizer_name': 'Arabic-Clip/m-bert-base-ViT-B-32-trained-mclip-data',
        'head_name': 'postTransformation_layer_linear_latest.pickle'
    },

    'M-BERT-Base-ViT-B-local': {
        'model_name': '/home/think3/Desktop/2. tf_testing_araclip/Testing_conversion_tf_to_pt/M-BERT-Base-ViT-B',
        'tokenizer_name': '/home/think3/Desktop/2. tf_testing_araclip/Testing_conversion_tf_to_pt/M-BERT-Base-ViT-B',
        'head_name': 'M-BERT-Base-69-ViT Linear Weights.pkl'
    },

    'arabert-large-vit-base-32-epoch-16': {
        'model_name': 'Arabic-Clip/arabert-large-vit-base-32-epoch-16',
        'tokenizer_name': 'Arabic-Clip/arabert-large-vit-base-32-epoch-16',
        'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-32-16_.pickle'
    },
    'arabert-large-vit-base-32-epoch-21': {
        'model_name': 'Arabic-Clip/arabert-large-vit-base-32-epoch-21',
        'tokenizer_name': 'Arabic-Clip/arabert-large-vit-base-32-epoch-21',
        'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-32-21_.pickle'
    },

    'arabert-large-vit-base-32-epoch-26': {
        'model_name': 'Arabic-Clip/arabert-large-vit-base-32-epoch-21',
        'tokenizer_name': 'Arabic-Clip/arabert-large-vit-base-32-epoch-21',
        'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-32-26_.pickle'
    },

    'arabert-large-vit-B-16-plus-epoch-6': {
        'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-6-trained-1M-corrupted',
        'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-6-trained-1M-corrupted',
        'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-6_.pickle'
    },

    'arabert-large-vit-B-16-plus-epoch-11': {
        'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-11-trained-1M-corrupted',
        'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-11-trained-1M-corrupted',
        'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-11_.pickle'
    },
    
    'arabert-large-vit-B-16-plus-epoch-16': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-16-trained-1M-corrupted',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-16-trained-1M-corrupted',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-16_.pickle'
    },
    
    'arabert-large-vit-B-16-plus-epoch-21': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-21-trained-1M-corrupted',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-21-trained-1M-corrupted',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-21_.pickle'
    },

    'arabert-large-vit-B-16-plus-epoch-26': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-26-trained-1M-corrupted',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-26-trained-1M-corrupted',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-26_.pickle'
    },

    'arabert-large-vit-B-16-plus-epoch-31': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-31-trained-1M-corrupted',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-31-trained-1M-corrupted',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-31_.pickle'
    },

    'arabert-large-vit-B-16-plus-epoch-36': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-36-trained-1M-corrupted',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-36-trained-1M-corrupted',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-36_.pickle'
    },

    'arabert-large-vit-B-16-plus-epoch-41': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-41-trained-1M-corrupted',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-41-trained-1M-corrupted',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-41_.pickle'
    },


    'arabert-large-vit-B-16-plus-epoch-4': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-4-trained-3M-5M',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-4-trained-3M-5M',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-4_.pickle'
    },

    'arabert-large-vit-B-16-plus-epoch-5-3M-5M': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-5-trained-3M-5M-from-scratch',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-5-trained-3M-5M-from-scratch',
    'head_name': 'arabert_v2_vit_B_16_plusheads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-5_.pickle'
    },


    'arabert-large-vit-B-16-plus-epoch-23-on-top-1M-corrupted': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-23-trained-3M-5M-on-top-1M-corrupted',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-23-trained-3M-5M-on-top-1M-corrupted',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-23_.pickle'
    },
    
    'arabert-large-vit-B-16-plus-mscoc-11': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-11-trained-mscoco-training',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-11-trained-mscoco-training',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-11_mscoco_.pickle'
    },
    
    'arabert-large-vit-B-16-plus-mscoc-60': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-60-trained-mscoco-training',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-60-trained-mscoco-training',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-60_.pickle'
    },
    'arabert-large-vit-B-16-plus-mscoc-60-32': {
    'model_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-60-trained-mscoco-training-fp32',
    'tokenizer_name': 'Arabic-Clip/arabertv2-Vit-B-16-plus-epoch-60-trained-mscoco-training-fp32',
    'head_name': 'heads_of_the_model_bert-large-arabertv2-Vit-B-16-plus-240-60_32.pickle'
    },

    'arbertv2-large-vit-B-16-epoch-54-2M_3M_5M': {
    'model_name': 'Arabic-Clip/arbertv2-Vit-B-16-plus-epoch-54-2M_3M_5M',
    'tokenizer_name': 'Arabic-Clip/arbertv2-Vit-B-16-plus-epoch-54-2M_3M_5M',
    'head_name': 'ARBERTv2_vit_B_16_plusheads_of_the_model_ARBERTv2-Vit-B-16-plus-240-54_.pickle'
    },

    'arbertv2-Vit-B-16-plus-epoch-200-msoco':{
    'model_name': 'Arabic-Clip/arbertv2-Vit-B-16-plus-epoch-200-msoco',
    'tokenizer_name': 'Arabic-Clip/arbertv2-Vit-B-16-plus-epoch-200-msoco',
    'head_name': 'ARBERTv2_vit_B_16_plusheads_of_the_model_ARBERTv2-Vit-B-16-plus-240-200_.pickle'
    },
    

    'bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-55-trained-2M':{
    'model_name': 'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-55-trained-2M',
    'tokenizer_name': 'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-55-trained-2M',
    'head_name': 'arabertv2-vit-B-16-siglibheads_of_the_model_arabertv2-ViT-B-16-SigLIP-512-55_.pickle'
    },

    'bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M':{
    'model_name': 'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M',
    'tokenizer_name': 'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M',
    'head_name': 'arabertv2-vit-B-16-siglibheads_of_the_model_arabertv2-ViT-B-16-SigLIP-512-155_.pickle'
    },

    'bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M-mscoco-200':{
    'model_name':'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M-mscoco-200',
    'tokenizer_name':'Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M-mscoco-200',
    'head_name': 'arabertv2-vit-B-16-siglib-mscocoarabertv2-vit-B-16-siglibheads_of_the_model_arabertv2-ViT-B-16-SigLIP-512-200_.pickle'
    },

    'Arabert-v2-base-ViT-B-16-SigLIP-512-2M':{
    'model_name':'Arabic-Clip/Arabert-v2-base-ViT-B-16-SigLIP-512-2M',
    'tokenizer_name':'Arabic-Clip/Arabert-v2-base-ViT-B-16-SigLIP-512-2M',
    'head_name': 'Arabert-v2-base-ViT-B-16-SigLIP-512-2M.pickle'
    },
}


In [27]:
def load_model(name, cache_dir=None,in_features=None,out_features=None):
    config = AVAILABLE_MODELS[name]
    print(config)
    return MultilingualClipEdited(**config, cache_dir=cache_dir, in_features= in_features, out_features=out_features)

In [28]:
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# device

In [29]:
# !wget https://huggingface.co/Arabic-Clip/Arabert-v2-base-ViT-B-16-SigLIP-512-2M/resolve/main/Arabert-v2-base-ViT-B-16-SigLIP-512-2M.pickle


In [30]:
# Open the pickle file in binary read mode

pickle_file_path = 'Arabert-v2-base-ViT-B-16-SigLIP-512-2M.pickle'  # Replace with the actual path to your pickle file
with open(pickle_file_path, 'rb') as file:
    loaded_content = pickle.load(file)
    print(len(loaded_content))
    print(loaded_content[0].shape)
    print(loaded_content[1].shape)

FileNotFoundError: [Errno 2] No such file or directory: 'Arabert-v2-base-ViT-B-16-SigLIP-512-2M.pickle'

In [31]:

# Text model name 
text_model = load_model('Arabert-v2-base-ViT-B-16-SigLIP-512-2M', in_features= 768, out_features=768)


# Define the language model with lambda 

language_model = lambda queries: np.asarray(text_model(queries).detach().to('cpu')) 

In [32]:
text_model


MultilingualClipEdited(
  (transformer): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(64000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

### Define the image model 

In [33]:
# !pip install open_clip_torch

In [32]:
# clip_model, compose = clip.load('RN50x4')
# import torch
# import open_clip
import torch
import torch.nn.functional as F
from urllib.request import urlopen
from PIL import Image
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8


device = "cuda" if torch.cuda.is_available() else "cpu"

print("Device: ", device)

# clip_model, _, compose = open_clip.create_model_and_transforms('ViT-B-16-plus-240', pretrained="laion400m_e32")
# tokenizer = open_clip.get_tokenizer('ViT-B-16-plus-240')
# clip_model.to(device)


clip_model, compose = create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP-512')
tokenizer = get_tokenizer('hf-hub:timm/ViT-B-16-SigLIP-512')

Device:  cuda


In [33]:
compose

Compose(
    Resize(size=(512, 512), interpolation=bicubic, max_size=None, antialias=None)
    <function _convert_to_rgb at 0x7f42436e85e0>
    ToTensor()
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
)

In [34]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [35]:
clip_model.to(device)

CustomTextCLIP(
  (visual): TimmModel(
    (trunk): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768

### Defind  the image model 

In [36]:
image_model = lambda images: np.asarray(clip_model.encode_image(images.to(device)).float().detach().to('cpu'))

# Utils

In [37]:
# Define the needed libraries in the code 

from tqdm.notebook import tqdm
import os 

from PIL import Image

### Defind a dataset class for images 

In [38]:

class CustomDataSet(torch.utils.data.Dataset):
    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        self.total_imgs = image_name_list
        print(self.total_imgs)

    def __len__(self):
        return len(self.total_imgs)

    def get_image_name(self, idx):

        return self.total_imgs[idx]

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = Image.open(img_loc)

        return self.transform(image)

### Defind a dataset class for text dataset  

In [39]:
class SimpleTextDataset(torch.utils.data.Dataset):

    def __init__(self, texts):
        """Define  the class init"""
        self.texts = texts

    def __len__(self):
        """Return the length of the text dataset"""
        return len(self.texts)

    def __getitem__(self, idx):
        """Get the item based on index"""
        return self.texts[idx]

In [40]:
def text_encoder(text):
    """Normalize the text embeddings"""
    embedding = language_model(text)
    embedding = embedding / np.linalg.norm(embedding)

    return embedding

def precompute_text_features(loader):
    """Compute the text embeddings of the whole dataset based on the loader provided"""
    text_features = []

    for _, (texts) in enumerate(tqdm(loader)):

        embedding = language_model(texts)
        embedding = embedding / np.linalg.norm(embedding)

        text_features.extend(embedding)

    return np.array(text_features)

In [41]:
def precompute_image_features(loader):
    image_features = []
    
    for i, (images) in enumerate(tqdm(loader)):

        features = image_model(images)

        features = features / np.linalg.norm(features)
        image_features.extend(features)

    return np.array(image_features)

In [42]:
def show_images(image_list):
    for im_path in image_list:
        print(im_path)
        display(Image.open(im_path))

In [43]:
# text = 'بجعة تطفو أسفل النهر بالقارب'

# image_paths = find_image(text, dataset, image_features, n=3)
# show_images(image_paths)

Build the image dataset 

In [44]:
dataset = CustomDataSet("photos/XTD10_dataset", transform=compose)

['COCO_train2014_000000061844.jpg', 'COCO_train2014_000000061849.jpg', 'COCO_train2014_000000061850.jpg', 'COCO_train2014_000000061852.jpg', 'COCO_train2014_000000061854.jpg', 'COCO_train2014_000000061865.jpg', 'COCO_train2014_000000061867.jpg', 'COCO_train2014_000000061877.jpg', 'COCO_train2014_000000061881.jpg', 'COCO_train2014_000000061892.jpg', 'COCO_train2014_000000061895.jpg', 'COCO_train2014_000000061904.jpg', 'COCO_train2014_000000061911.jpg', 'COCO_train2014_000000061918.jpg', 'COCO_train2014_000000061919.jpg', 'COCO_train2014_000000061936.jpg', 'COCO_train2014_000000061945.jpg', 'COCO_train2014_000000061946.jpg', 'COCO_train2014_000000061949.jpg', 'COCO_train2014_000000061951.jpg', 'COCO_train2014_000000061966.jpg', 'COCO_train2014_000000061982.jpg', 'COCO_train2014_000000061992.jpg', 'COCO_train2014_000000062017.jpg', 'COCO_train2014_000000062029.jpg', 'COCO_train2014_000000062030.jpg', 'COCO_train2014_000000062031.jpg', 'COCO_train2014_000000062038.jpg', 'COCO_train2014_000

In [45]:
# check if the image_paths sorted_data in the same order of the image dataset:


for i, item in enumerate(sorted_data):

    if item['image_name'] != dataset.get_image_name(i):
        print("stop")
        break


In [46]:
len(dataset)

1000

### Define the image_loder

In [47]:
image_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    drop_last=False)

### Define the text_loder

In [48]:
text_dataset = SimpleTextDataset([elem["caption_ar"] for elem in sorted_data])

text_loader = torch.utils.data.DataLoader(
    text_dataset,
    batch_size=64,
    shuffle=False)

In [49]:
# Check this to utalize the GPU memory in the images 
# https://discuss.pytorch.org/t/not-using-multiprocessing-but-getting-cuda-error-re-forked-subprocess/54610/8

In [50]:
import numpy as np

In [52]:
# !pip install ipywidgets

In [53]:
image_features = precompute_image_features(image_loader)

  0%|          | 0/63 [00:00<?, ?it/s]

In [54]:
image_emb_path = 'image_features.pickle'

In [55]:
text_emb_path = 'text_features.pickle'

In [56]:
import pickle


with open(image_emb_path, 'wb') as handle:
    pickle.dump(image_features, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [57]:
import pickle

with open(image_emb_path, 'rb') as handle:
    image_features_new = pickle.load(handle)

image_features_new

array([[-2.9428354e-03,  4.5608878e-04,  8.0113132e-03, ...,
        -5.5744208e-04,  1.3973909e-02, -2.0840494e-03],
       [ 7.1692630e-03, -3.6415604e-03, -7.4449009e-03, ...,
         1.9552663e-03, -2.5131479e-03, -9.7969100e-03],
       [-2.2590715e-03,  7.8871623e-03,  1.7792708e-03, ...,
         2.2906056e-03, -1.4292294e-03,  2.5328281e-03],
       ...,
       [-1.8433719e-04, -1.0952966e-02, -1.2227783e-03, ...,
        -3.1431387e-03, -1.2975362e-02, -5.3121448e-03],
       [ 5.6465985e-03, -5.5815902e-04, -2.4672919e-03, ...,
        -9.0176007e-03,  1.0915064e-03, -1.2074071e-03],
       [-1.2275180e-02, -1.4984320e-03,  1.1497841e-02, ...,
        -1.2154064e-02, -4.6428460e-05, -1.9285224e-02]], dtype=float32)

In [58]:
text_features = precompute_text_features(text_loader)

text_features

  0%|          | 0/16 [00:00<?, ?it/s]

embs shape:  torch.Size([64, 47, 768])
att shape:  torch.Size([64, 47])
embs after att shape:  torch.Size([64, 768])
embs shape:  torch.Size([64, 32, 768])
att shape:  torch.Size([64, 32])
embs after att shape:  torch.Size([64, 768])
embs shape:  torch.Size([64, 29, 768])
att shape:  torch.Size([64, 29])
embs after att shape:  torch.Size([64, 768])
embs shape:  torch.Size([64, 29, 768])
att shape:  torch.Size([64, 29])
embs after att shape:  torch.Size([64, 768])
embs shape:  torch.Size([64, 48, 768])
att shape:  torch.Size([64, 48])
embs after att shape:  torch.Size([64, 768])
embs shape:  torch.Size([64, 32, 768])
att shape:  torch.Size([64, 32])
embs after att shape:  torch.Size([64, 768])
embs shape:  torch.Size([64, 28, 768])
att shape:  torch.Size([64, 28])
embs after att shape:  torch.Size([64, 768])
embs shape:  torch.Size([64, 30, 768])
att shape:  torch.Size([64, 30])
embs after att shape:  torch.Size([64, 768])
embs shape:  torch.Size([64, 38, 768])
att shape:  torch.Size([6

array([[ 0.00465822, -0.00300487,  0.002295  , ..., -0.00158944,
         0.0019236 , -0.00171525],
       [ 0.00326471, -0.00549089, -0.00176563, ...,  0.00211979,
         0.00070003, -0.00064166],
       [ 0.00025349,  0.00138306, -0.00057553, ...,  0.00080644,
        -0.00581361, -0.00391068],
       ...,
       [ 0.00169416, -0.00047155,  0.00539855, ...,  0.00153624,
        -0.00576969,  0.00184906],
       [-0.00063449,  0.00550029,  0.00108185, ...,  0.00402879,
         0.00355745, -0.00296559],
       [-0.0089858 , -0.0035024 ,  0.00538795, ..., -0.00300943,
        -0.00760316, -0.00625956]], dtype=float32)

In [59]:
import pickle


with open(text_emb_path, 'wb') as handle:
    pickle.dump(text_features, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [60]:

with open(text_emb_path, 'rb') as handle:
    text_features_new = pickle.load(handle)

text_features_new

array([[ 0.00465822, -0.00300487,  0.002295  , ..., -0.00158944,
         0.0019236 , -0.00171525],
       [ 0.00326471, -0.00549089, -0.00176563, ...,  0.00211979,
         0.00070003, -0.00064166],
       [ 0.00025349,  0.00138306, -0.00057553, ...,  0.00080644,
        -0.00581361, -0.00391068],
       ...,
       [ 0.00169416, -0.00047155,  0.00539855, ...,  0.00153624,
        -0.00576969,  0.00184906],
       [-0.00063449,  0.00550029,  0.00108185, ...,  0.00402879,
         0.00355745, -0.00296559],
       [-0.0089858 , -0.0035024 ,  0.00538795, ..., -0.00300943,
        -0.00760316, -0.00625956]], dtype=float32)

In [61]:
image_features_new

array([[-2.9428354e-03,  4.5608878e-04,  8.0113132e-03, ...,
        -5.5744208e-04,  1.3973909e-02, -2.0840494e-03],
       [ 7.1692630e-03, -3.6415604e-03, -7.4449009e-03, ...,
         1.9552663e-03, -2.5131479e-03, -9.7969100e-03],
       [-2.2590715e-03,  7.8871623e-03,  1.7792708e-03, ...,
         2.2906056e-03, -1.4292294e-03,  2.5328281e-03],
       ...,
       [-1.8433719e-04, -1.0952966e-02, -1.2227783e-03, ...,
        -3.1431387e-03, -1.2975362e-02, -5.3121448e-03],
       [ 5.6465985e-03, -5.5815902e-04, -2.4672919e-03, ...,
        -9.0176007e-03,  1.0915064e-03, -1.2074071e-03],
       [-1.2275180e-02, -1.4984320e-03,  1.1497841e-02, ...,
        -1.2154064e-02, -4.6428460e-05, -1.9285224e-02]], dtype=float32)

In [62]:
text_features_new.shape

(1000, 768)

In [63]:
text_features_new[0][:]

array([ 4.65822406e-03, -3.00486828e-03,  2.29499745e-03,  1.13717071e-03,
        1.38161296e-04, -9.60945617e-04, -4.02267603e-03, -2.11910577e-03,
       -1.00674061e-03,  1.96714560e-03,  4.19034390e-04, -4.26500337e-03,
        2.01693410e-03,  8.40586144e-03,  1.37070147e-03, -2.33583059e-03,
       -6.58896577e-04,  4.78198240e-03, -6.63301171e-06, -3.03384359e-03,
        8.51550512e-03,  4.44406550e-03,  1.02539023e-03,  8.78327468e-04,
       -5.09183563e-04,  4.75151057e-04, -5.01624029e-03, -3.33817769e-03,
       -1.66290498e-03,  2.83397874e-03, -1.21037674e-03,  2.79206014e-03,
       -4.95796383e-04,  6.22524414e-04, -9.82807833e-04, -1.28187716e-03,
       -3.98709235e-04, -1.58936542e-03, -5.16854692e-04,  3.47565580e-03,
        2.21085385e-03,  2.94526108e-04, -5.02343429e-03,  1.58968891e-04,
        1.23343046e-03,  7.15488393e-04, -4.74927452e-04,  7.96543551e-04,
       -6.58445351e-04, -4.42503579e-03, -4.55476716e-03,  5.20684291e-03,
        1.66639697e-03, -

In [64]:
image_features_new[0][:]

array([-2.94283545e-03,  4.56088776e-04,  8.01131316e-03, -5.11716213e-03,
       -9.17849771e-04, -1.13405436e-02, -1.73408596e-03, -8.88058171e-03,
       -1.58235207e-02, -5.62402711e-04, -8.57300404e-03, -6.95090182e-03,
       -7.96150824e-04,  9.19742696e-03, -1.83630688e-03, -7.88219552e-03,
        1.41508621e-03,  9.83527489e-03, -5.43090864e-04,  2.02741998e-04,
        7.45872548e-03,  6.69360859e-03, -8.85412883e-05,  6.18533741e-05,
       -2.17115367e-03,  4.52068960e-03,  1.85395975e-03,  9.49008577e-03,
       -2.79319449e-03,  1.50575321e-02,  3.05480929e-03,  1.34937828e-02,
        5.09243133e-03, -7.57530378e-03, -7.81269860e-04, -5.84949367e-03,
       -1.02735050e-02,  6.94800587e-03,  5.79292770e-04, -4.63431841e-03,
       -6.44110376e-03,  1.02458699e-02,  1.14735952e-02,  7.11104972e-03,
       -6.00119121e-03,  6.76115524e-05,  3.16453137e-04, -2.76640966e-03,
        6.56399131e-03,  9.25119035e-04, -1.21195626e-03,  9.13234148e-03,
        6.10010291e-04,  

In [65]:
# # Take a look later over this

# logit_scale = clip_model.logit_scale.exp().float().detach().to('cpu')
# print(logit_scale)
# logit_scale * text_features_new

In [66]:
# logit_scale_val = logit_scale.item()

In [67]:
def get_path_coco(image_id):
    # image_id = int(image_id)
    # print(type(image_id))

    im_path = id2path[image_id]
    
    return f"photos/XTD10_dataset/{im_path}" # f"photos/val2014/COCO_val2014_{image_id:012d}.jpg"

In [68]:
import numpy as np

In [69]:
mat_indx_mrr = np.zeros((1000,1000),dtype=np.int64)

In [70]:
mat_indx_mrr.shape

(1000, 1000)

In [71]:
mat_indx_mrr

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [72]:
collect_rr_testing = []

In [73]:
# Check which axis the for loop get back
# So, it loop over the raws

chck_found = np.random.randint(10, size=(2, 4))
for index, distances in enumerate(chck_found):
    print(index)
    print(distances)

0
[4 3 5 3]
1
[1 1 4 4]


In [74]:
# Check the scores  

text_features_new.shape

(1000, 768)

In [75]:
image_features_new.shape

(1000, 768)

In [76]:
(text_features_new * image_features_new).shape

(1000, 768)

In [77]:
np.matmul(text_features_new[999], image_features_new[999].T) * 100

0.6813209969550371

In [78]:
def compare_embeddings(logit_scale, img_embs, txt_embs):
  # normalized features
  image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
  text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)


  # logits_per_image = logit_scale * image_features @ text_features.t()


  logits_per_text = logit_scale * text_features @ image_features.t()

  # print("type: ", type(logits_per_text))
  
  return logits_per_text

In [79]:
# https://github.com/gpleiss/temperature_scaling
# CLIP Temperature scaler
logit_scale = clip_model.logit_scale.exp().float().to('cpu')

print(logit_scale)

language_logits = {}

 


language_logits["Arabic"] = compare_embeddings(logit_scale, torch.from_numpy(image_features_new), torch.from_numpy(text_features_new))
language_logits["Arabic"].shape

tensor(117.8218, grad_fn=<ToCopyBackward0>)


torch.Size([1000, 1000])

In [80]:
language_logits

{'Arabic': tensor([[  8.1210, -13.8103, -16.2311,  ..., -14.3900, -29.8156, -25.9899],
         [-20.9396,  14.0967, -21.5462,  ..., -21.3001, -22.7686, -19.4049],
         [-12.1298, -23.4617,  14.5536,  ..., -18.1756, -27.9193, -14.8679],
         ...,
         [ -0.1278,  -5.8619,  -6.6187,  ...,   6.8610, -17.4344,  -8.8681],
         [-19.4821, -17.4807, -21.4857,  ..., -20.9604,  14.7999, -18.6618],
         [-18.6599, -21.0871, -17.3422,  ..., -19.2643, -25.6973,  15.3250]],
        grad_fn=<MmBackward0>)}

In [81]:
type(language_logits)

dict

In [82]:
language_logits["Arabic"].shape

torch.Size([1000, 1000])

In [83]:
txt_logits = language_logits["Arabic"]

In [84]:
txt_logits.shape

torch.Size([1000, 1000])

In [85]:
import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'matplotlib'

In [86]:
# def plot_heatmap(result_matrix):
#   height, width = result_matrix.shape
#   fig, ax = plt.subplots()
#   fig.set_size_inches(50,50)
#   im = ax.imshow(result_matrix)


#   # Create X & Y Labels
#   ax.set_xticks(np.arange(width))
#   ax.set_yticks(np.arange(height))
#   plt.xticks(rotation=90)

#   # ax.set_xticklabels(["Image {}".format(i) for i in range(width)])
#   # ax.set_yticklabels(["Text {}".format(i) for i in range(height)])

#   for i in range(height):
#     for j in range(width):
#         text = ax.text(j, i, result_matrix[i, j],
#                        ha="center", va="center", color='grey', size=5)

#   # fig.tight_layout()
#   plt.show()

# for lang, txt_logits in language_logits.items():
   
#   # Convert Logits into Softmax predictions
#   bot_range_ind = 400
#   top_range_ind = 500
#   probs = txt_logits[bot_range_ind:top_range_ind,bot_range_ind:top_range_ind].softmax(dim=-1).cpu().detach().numpy()

#   # Transpose so that each column is the softmax for each picture over the texts
#   probs = np.around(probs, decimals=2).T * 100

#   print("Language: {}".format(lang))
#   plot_heatmap(probs)

In [87]:
sorted_data[400+25]

{'caption_en': 'a half cut pizza on a plate on the table',
 'caption_ar': 'نصف بيتزا مقطعة على طبق على الطاولة',
 'image_name': 'COCO_val2014_000000127476.jpg',
 'id': 127476}

In [88]:
sorted_data[400+86]

{'caption_en': 'two pieces of pizza on a plate with a knife and fork laying on the plate',
 'caption_ar': 'قطعتان من البيتزا على طبق به سكين وشوكة على الطبق',
 'image_name': 'COCO_val2014_000000128180.jpg',
 'id': 128180}

In [89]:
# trial_1 = []

In [90]:
def compute_mrr(data, dataset, n):
    """Compute the MRR for the data based on n"""
    collect_rr = []
    pbar = tqdm(total=len(data), position=0, leave=True)

    # print("text_features")
    # print(text_features)
    # print("image_features")
    # print(image_features)

    # print("image_features shape: ")
    # print(image_features.shape)
    # print()
    # print("text_features shape: ")
    # print(text_features.shape)
    # found = np.matmul(text_features, image_features.T)
    found = np.matmul(text_features_new, image_features_new.T)

    # # instead: first shift the values of f so that the highest number is 0:
    # found -= np.max(found)
    # found_scalled = np.exp(found) / np.sum(np.exp(found)) # safe to do, gives the correct answer


    # found_scalled = softmax(found) # .softmax(dim=-1).cpu().detach().numpy()
    # print("print the matrix for the text features and the images featutes maltiplication found")

    # print(found)

    for index, distances in enumerate(found): # It return the rows, one by one

        pbar.update(1)
        # print()
        # print("index: ", index)
        # print("data[index]['id']: inside the loop", data[index]["id"])
        image_path = get_path_coco(data[index]["id"])
        # print(data[index]["id"])
        # print("New link")
        # print("image_path in compute_mrr ", image_path)
        # print("caption: ", data[index]["caption"])
        # print("distances")
        # print(distances)
        # print("n: ", n)

        
        collect_rr.append(new_rr(distances, image_path, dataset, n,index))


    pbar.close()
    print(100*"=")
    # trial_1 = collect_rr.copy()
    # print(collect_rr)
    
    return np.average(collect_rr)


def new_rr(distances, target_image, dataset, n):
    """Calculate the RR for the given target image"""
    image_paths = []

    # print("distances: ", distances)
    # print("type(distances): ", type(distances))
    idxs = distances.argsort()[-n:][::-1] # Get the indcies for the images distances based on n

    # print(idxs)
        
    # print(type(idxs))

    # idxs = distances.argsort()[-n:][::-1] # Get the indcies for the images distances based on n
    
    # print("distances.argsort(): ", distances.argsort())
    # print("distances.argsort()[-n:]: ", distances.argsort()[-n:])
    # print("distances.argsort()[-n:][::-1]: ", distances.argsort()[-n:][::-1])

    # print("idxs of the images from the top to the lower: ", idxs)
    # print("target_image: ", target_image)
    for idx in idxs:
        # print("'photos/val2014/' + dataset.get_image_name(idx): ", 'photos/val2014/' + dataset.get_image_name(idx))
        image_paths.append('photos/XTD10_dataset/' + dataset.get_image_name(idx))
        # image_paths.append(get_path_coco(data[idx]["id"]))

    # print("target_image: ", target_image)
    # print("image_paths: ", image_paths)

    if target_image in image_paths:

        return 1/(image_paths.index(target_image) + 1)
    else:
        # print("new_rr: ", 0)
        return 0


def internal_hits(distances, target_image, dataset, n):
    """Calculate the hits of the target images based on the existance of it or not"""
    image_paths = []
    idxs = distances.argsort()[-n:][::-1]

    if target_image in idxs:
        return 1
    else:
        return 0

def compute_hits(data, dataset, n):

    index_cnt = 0

    collect_rr = []

    pbar = tqdm(total=len(data), position=0, leave=True)

    found = np.matmul(text_features_new, image_features_new.T)

    for index, distances in enumerate(found):
        pbar.update(1)
        # image_path = get_path_coco(data[index]["id"])
        image_path = index # get_path_coco(data[index]["id"])
        
        collect_rr.append(internal_hits(distances, image_path, dataset, n))
        # collect_rr_testing.append(internal_hits(distances, image_path, dataset, n))
        break

    
    pbar.close()
    # print(len(collect_rr_testing))
    return np.average(collect_rr)

In [91]:
# def compute_mrr(data, dataset, n):
#     """Compute the MRR for the data based on n"""
#     collect_rr = []

#     found = np.matmul(text_features, image_features.T)


#     for index, cos_vlaues in enumerate(found):

#         image_path = get_image_path(data[index]["id"])

#         result = 0

#         image_paths = []

#         idxs = cos_vlaues.argsort()[-n:][::-1] 
        
#         for idx in idxs:
#             image_paths.append(get_image_path(idx))

#         if target_image in image_paths:

#             result = 1/(image_paths.index(target_image) + 1)

#         collect_rr.append(result)

#     return np.average(collect_rr)


In [92]:
def compute_mrr(data, dataset, n):
    """Compute the MRR for the data based on n"""
    collect_rr = []

    found = np.matmul(text_features_new, image_features_new.T)
    for index, distances in enumerate(found): # It return the rows, one by one

        image_path = get_path_coco(data[index]["id"])
        collect_rr.append(new_rr(distances, image_path, dataset, n,index))

        

    return np.average(collect_rr)

def new_rr(distances, target_image, dataset, n,index):
    """Calculate the RR for the given target image"""
    image_paths = []

    idxs = distances.argsort()[-n:][::-1] 
    
    # print("target_image: ", target_image)
    

    for idx in idxs:
        image_paths.append('photos/XTD10_dataset/' + dataset.get_image_name(idx))
    

    # print("image_paths: ", image_paths)
    
    if target_image in image_paths:

        return 1/(image_paths.index(target_image) + 1)
    else:
        return 0


In [93]:
# # image_encoder - ResNet or Vision Transformer
# # text_encoder - CBOW or Text Transformer
# # I[n, h, w, c] - minibatch of aligned images
# # T[n, l] - minibatch of aligned texts
# # W_i[d_i, d_e] - learned proj of image to embed
# # W_t[d_t, d_e] - learned proj of text to embed
# # t - learned temperature parameter
# # extract feature representations of each modality
# I_f = image_encoder(I) #[n, d_i]
# T_f = text_encoder(T) #[n, d_t]
# # joint multimodal embedding [n, d_e]
# I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
# T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# # scaled pairwise cosine similarities [n, n]
# logits = np.dot(I_e, T_e.T) * np.exp(t)
# # symmetric loss function
# labels = np.arange(n)
# loss_i = cross_entropy_loss(logits, labels, axis=0)
# loss_t = cross_entropy_loss(logits, labels, axis=1)
# loss = (loss_i + loss_t)/2


# Figure 3. Numpy-like pseudocode for the core of an implementa-
# tion of CLIP.

In [94]:
print('MRR@1:', compute_mrr(sorted_data, dataset, 1))

MRR@1: 0.673


In [95]:
print('MRR@5:', compute_mrr(sorted_data, dataset, 5))

MRR@5: 0.7522833333333333


In [96]:
print('MRR@10:', compute_mrr(sorted_data, dataset,10))

MRR@10: 0.7618686507936507


In [97]:
# print(compute_hits(sorted_data, dataset, 1)* 100)

In [98]:
# print(compute_hits(sorted_data, dataset, 5)* 100)

In [99]:
# print(compute_hits(sorted_data, dataset, 10)* 100)

## Evaluation based on Recall metric

In [100]:
image_features_new.shape

(1000, 768)

In [101]:
text_features_new.shape

(1000, 768)

In [102]:
image_features_new_pt = torch.from_numpy(image_features_new)

text_features_new_pt = torch.from_numpy(text_features_new)

text_to_image_map = torch.LongTensor(list(range(text_features_new.shape[0])))
print(text_to_image_map.shape) # .type(torch.int64)

print(text_to_image_map.unsqueeze(1).shape)

torch.Size([1000])
torch.Size([1000, 1])


In [103]:
torch.set_printoptions(precision=8)

In [104]:
# https://github.com/openai/CLIP/issues/115
import torch
from torchvision.datasets import CocoCaptions
import torch.utils.data as dutils
from typing import List
import clip

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')




def recall_at_k(k_vals, image_encodings,text_encodings,text_to_image_map):
    print("Encoding all data...")
 
    num_text = text_encodings.shape[0]
    
    # text-to-image recall
    print("Text-to-image recall...")


    dist_matrix = text_encodings @ image_encodings.T  # dist_matrix[i] gives logits for ith text

    inds = torch.argsort(dist_matrix, dim=1, descending=True)
    inds = inds.to(device)
    text_to_image_recall = []

    

    text_to_image_map = text_to_image_map.to(device)
    
    for k in k_vals:
        # Extract top k indices only
        topk = inds[:, :k]

        text_to_image_map_new = text_to_image_map.repeat(k, 1).t()

        correct = torch.eq(topk, text_to_image_map_new).any(dim=1)  #  value along dimension 1 (which typically corresponds to rows in a 2D tensor) ###### any(dim=1) >> check if True over the row 
        
        num_correct = correct.sum().item()

        text_to_image_recall.append(num_correct / num_text)

    print(text_to_image_recall)

    print("Done.")
    return text_to_image_recall

In [105]:
k_vals = [1,5,10]
t2i= recall_at_k(k_vals=k_vals, image_encodings=image_features_new_pt,text_encodings=text_features_new_pt,text_to_image_map=text_to_image_map)

print("Text-to-image Recall@K")

print("Returned value: ", t2i)
for k, x in zip(k_vals, t2i):
    print(k, " ", (x/100) * 100)
    # print(f" R@{k}: {100*x:.2f}%")


Encoding all data...
Text-to-image recall...
[0.673, 0.878, 0.948]
Done.
Text-to-image Recall@K
Returned value:  [0.673, 0.878, 0.948]
1   0.673
5   0.878
10   0.9479999999999998
