In [None]:
#import dependencies 
from PIL import Image
import torch
from torch import nn, optim
import glob
import os
import pandas as pd
import json
import numpy as np
import clip
from torch.utils.data import Dataset, DataLoader, BatchSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import random
from matplotlib.pyplot import imshow
import nltk, re, string, collections
from nltk.util import ngrams
import collections
import shutil
import pickle

In [2]:
'''Arrange your dataset so that you have two folders:
  - a folder of images (only *.jpg files are picked up), and
  - a folder of JSON caption files, one per image, with the same file name as
    the image but a .json extension (img1.jpg -> img1.json).

The folder structure should look like this:
MAIN FOLDER
    IMAGES
        img1.jpg
        img2.jpg
    JSONS
        img1.json
        img2.json

Each JSON file should be a list containing a list of captions, e.g.
[
    [
        "caption"
    ]
]
Every caption in the inner list becomes one training sample for that image.
If a file contains several inner lists, only the last one is used.
'''

'Arrange your dataset such that you have two folders \nA folder of images \nA folder of jsons with the same name as images but json extension \nThe structure iof jsons should be as follows \nThe folder structure should look like this \nMAIN FOLDER \n    IMAGES\n        img1.jpg\n        img2.jpg\n    JSONS\n        img1.json\n        img2.json\n        \n\neach of the json file should look like this \n[\n    [\n        "caption"\n    ]\n]\n'

In [None]:
# Function that links images to their captions by building a dictionary.
def return_processed_data(IMG_ROOT, JSON_ROOT, dump_dir="/dump"):
    """Map every readable ``.jpg`` in ``IMG_ROOT`` to its list of captions.

    Pass 1 moves ``.jpg`` files that PIL cannot decode into ``dump_dir`` so they
    never reach the dataset. Other files in ``IMG_ROOT`` are left untouched.
    Pass 2 pairs each remaining ``<stem>.jpg`` with ``JSON_ROOT/<stem>.json``.
    The JSON is expected to be a list of caption lists; only the LAST inner
    list is kept. Images are skipped (with a message) when their JSON is
    missing, unreadable or empty, or when ``clip.tokenize`` raises on their
    captions.

    Args:
        IMG_ROOT: folder containing the images.
        JSON_ROOT: folder containing one JSON caption file per image.
        dump_dir: destination for undecodable images; created if missing.

    Returns:
        dict mapping image path -> caption list (insertion order = glob order).
    """
    # Pass 1: quarantine images PIL cannot decode. Image.open() is lazy (it
    # does not read pixel data), so a truncated file would slip through and
    # crash later in preprocessing; load() forces a full decode. The context
    # manager closes the file handle.
    for img_path in glob.glob(os.path.join(IMG_ROOT, "*.jpg")):
        try:
            with Image.open(img_path) as image:
                image.load()
        except Exception as e:
            os.makedirs(dump_dir, exist_ok=True)
            print(f"Moving unreadable image {img_path} to {dump_dir}: {e}")
            shutil.move(img_path, os.path.join(dump_dir, os.path.basename(img_path)))

    # Pass 2: pair each surviving image with its caption file.
    d = {}
    for img_path in glob.glob(os.path.join(IMG_ROOT, "*.jpg")):
        # splitext keeps dots inside the name ("a.b.jpg" -> "a.b").
        name = os.path.splitext(os.path.basename(img_path))[0]
        try:
            with open(os.path.join(JSON_ROOT, name + ".json"), "r") as f:
                captions = json.load(f)
            # Keep the last inner caption list. Indexing also rejects an empty
            # file instead of reusing the previous image's captions.
            caption_group = captions[-1]
            # Validation only: skip images whose captions tokenize rejects.
            clip.tokenize(caption_group)
        except Exception as e:
            print(f"Skipping {img_path}: {e}")
            continue
        d[img_path] = caption_group

    print(len(d))
    return d

In [None]:


device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Seed every RNG the notebook touches (Python, NumPy -- used by
# BalancedBatchSampler for shuffling/choosing -- and torch) so batch
# composition and results are reproducible across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# BalancedBatchSampler draws BATCH_SIZE distinct images per batch without
# replacement, so each split needs at least BATCH_SIZE images.
BATCH_SIZE = 128
EPOCH = 5

IMG_ROOT_train = "train_folder/images"
JSON_ROOT_train = "train_folder/queries"

IMG_ROOT_val = "val_folder/images"
JSON_ROOT_val = "val_folder/queries"

d_train = return_processed_data(IMG_ROOT_train, JSON_ROOT_train)
d_val = return_processed_data(IMG_ROOT_val, JSON_ROOT_val)
    


Data Loading 

In [None]:
# Pick the accelerator when present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# jit=False returns a regular nn.Module, so its weights can be fine-tuned.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

class RetrievalDataset(Dataset):
    """(image, caption, label) samples for CLIP fine-tuning.

    ``data`` maps image path -> iterable of captions (the dictionary built by
    ``return_processed_data``). Every caption becomes one sample. All captions
    of one image share that image's preprocessed tensor and its integer label,
    which is the image's position in ``data``.

    Every image is opened and passed through ``preprocess`` once in
    ``__init__`` and kept in memory, so memory use grows with the number of
    images.
    """

    def __init__(self, data, preprocess):
        self.preprocess = preprocess
        self.img_paths = []
        self.captions = []
        for img_path, captions in data.items():
            for cap in captions:
                self.img_paths.append(img_path)
                self.captions.append(cap)
        self.img_paths_set = list(data.keys())
        # enumerate yields each image's position directly; calling list.index()
        # for every key made this O(n^2) in the number of images.
        self.path2label = {path: label for label, path in enumerate(self.img_paths_set)}
        self.processed_cache = {path: self._load_and_preprocess(path) for path in self.img_paths_set}

    def _load_and_preprocess(self, img_path):
        """Open one image, apply ``preprocess`` and close the file handle."""
        with Image.open(img_path) as img:
            return self.preprocess(img)

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = self.processed_cache[img_path]
        caption = self.captions[idx]
        label = self.path2label[img_path]
        return image, caption, label

# One dataset per split; each preprocesses all of its images up front.
train_dataset, val_dataset = (RetrievalDataset(split, preprocess) for split in (d_train, d_val))


class BalancedBatchSampler(BatchSampler):
    """
    BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples.
    Returns batches of size n_classes * n_samples

    In this notebook every image is its own class and n_samples=1, so each
    batch holds captions of n_classes *distinct* images.

    Args:
        labels: 1-D tensor with the class label of every dataset index
            (``.numpy()`` is called on it).
        n_classes: classes drawn per batch. Must not exceed the number of
            distinct labels, because np.random.choice samples without
            replacement.
        n_samples: indices drawn from each selected class. A class with fewer
            remaining indices contributes fewer.
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        self.labels_set = list(set(self.labels.numpy()))
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        # Per-class cursor into the shuffled index list.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        # `<=` so exactly n_dataset // batch_size full batches are produced,
        # matching __len__ (with `<` an exact multiple lost its last batch).
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # Reshuffle a class once it cannot supply another full draw.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return self.n_dataset // self.batch_size

def _balanced_loader(dataset):
    """Label tensor, balanced sampler and DataLoader for one RetrievalDataset."""
    # dataset[idx][2] == path2label[img_paths[idx]]; read it without touching images.
    split_labels = torch.tensor([dataset.path2label[path] for path in dataset.img_paths])
    split_sampler = BalancedBatchSampler(split_labels, BATCH_SIZE, 1)
    return split_labels, split_sampler, DataLoader(dataset, batch_sampler=split_sampler)


train_labels, train_sampler, train_dataloader = _balanced_loader(train_dataset)
val_labels, val_sampler, val_dataloader = _balanced_loader(val_dataset)

# Sanity check 1: the first sampled batch should contain only distinct images.
for item in train_sampler:
    labels = [train_dataset[idx][2] for idx in item]
    break

print(len(labels), len(set(labels)))

# Sanity check 2: shapes and label uniqueness of one collated batch.
for batch in train_dataloader:
    imgs, txts, labels = batch
    for value in (imgs.shape, len(txts), labels, labels.shape, torch.unique(labels).shape):
        print(value)
    break







Model Training 

In [None]:
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient if any, to float32 in place.

    Called just before ``optimizer.step()`` on GPU so the update runs in fp32.
    Parameters that received no gradient (``p.grad is None``, e.g. frozen or
    unused ones) are still cast, but their missing gradient is left alone
    instead of raising AttributeError.
    """
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

# Run in full precision on the CPU.
if device == "cpu":
    model.float()

# Symmetric contrastive objective: one cross-entropy per retrieval direction.
loss_img, loss_txt = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=1e-5)
# The scheduler is stepped once per training batch, so T_max spans every batch of every epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_dataloader) * EPOCH)

In [None]:
# Checkpoints are written here; replace the placeholder with a real directory.
MODEL_DIR = "PATH_TO_SAVE_MODEL"
os.makedirs(MODEL_DIR, exist_ok=True)


def _batch_loss(batch):
    """Symmetric CLIP contrastive loss for one (images, captions, labels) batch.

    Image i and caption i form the positive pair, so both directions use
    ``arange(batch_len)`` as the target. The length comes from the batch
    itself rather than BATCH_SIZE, so a shorter batch cannot cause a shape error.
    """
    images, texts, _ = batch
    images = images.to(device)
    tokens = clip.tokenize(texts).to(device)
    logits_per_image, logits_per_text = model(images, tokens)
    ground_truth = torch.arange(images.shape[0], device=device)
    return (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2


best_val_loss = float("inf")
best_ep = -1
for epoch in range(EPOCH):
    print(f"running epoch {epoch}, best val loss {best_val_loss} after epoch {best_ep}")

    # ---- training ----
    model.train()
    step = 0
    tr_loss = 0.0
    pbar = tqdm(train_dataloader, leave=False)
    for batch in pbar:
        step += 1
        optimizer.zero_grad()
        total_loss = _batch_loss(batch)
        total_loss.backward()
        tr_loss += total_loss.item()
        if device == "cpu":
            optimizer.step()
        else:
            # Step in fp32, then convert the weights back with clip's helper.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)
        scheduler.step()
        pbar.set_description(f"train batchCE: {total_loss.item()}", refresh=True)
    tr_loss = tr_loss / step if step else float("nan")

    # ---- validation ----
    model.eval()
    step = 0
    val_loss = 0.0
    with torch.no_grad():
        val_pbar = tqdm(val_dataloader, leave=False)
        for batch in val_pbar:
            step += 1
            total_loss = _batch_loss(batch)
            val_loss += total_loss.item()
            val_pbar.set_description(f"val batchCE: {total_loss.item()}", refresh=True)
    # NaN (empty loader) never compares as an improvement, so no bogus "best" checkpoint.
    val_loss = val_loss / step if step else float("nan")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_ep = epoch
        print(best_ep)
        torch.save(model.state_dict(), os.path.join(MODEL_DIR, "Best.pt"))
    print(f"epoch {epoch}, tr_loss {tr_loss}, val_loss {val_loss} , best_epoch was {best_ep}, best val loss {best_val_loss}")
torch.save(model.state_dict(), os.path.join(MODEL_DIR, "Last.pt"))

Using the saved best model, we create text and image embeddings for our test dataset. These let us evaluate the model's retrieval performance and also run inference with it.

In [None]:
IMG_ROOT = "test_images"
JSON_ROOT = "test_queries"
# Same location the training cell writes its best checkpoint to.
BEST_MODEL_PATH = "PATH_TO_SAVE_MODEL/Best.pt"

img_paths = glob.glob(os.path.join(IMG_ROOT, "*.jpg"))

# image path -> last inner caption list of its JSON file.
d = {}
skipped = 0
for img_path in img_paths:
    name = os.path.splitext(os.path.basename(img_path))[0]
    try:
        with open(os.path.join(JSON_ROOT, name + ".json"), "r") as f:
            captions = json.load(f)
        # Indexing rejects an empty file instead of reusing the previous image's captions.
        d[img_path] = captions[-1]
    except Exception:
        skipped += 1

print(f"{len(d)} test images loaded, {skipped} skipped")
if not d:
    raise ValueError(f"No usable image/caption pairs found in {IMG_ROOT} and {JSON_ROOT}")

# Run inference on the device the model already lives on (overriding `device`
# here left the model on CUDA with inputs on the CPU). map_location loads the
# checkpoint onto that same device whatever device it was saved from.
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
model.eval()

first_item = next(iter(d.items()))
print(first_item)

embeddings_dictionary_text = {}
embeddings_dictionary_images = {}
with torch.no_grad():
    for img_path in tqdm(d):
        # Only the first caption of each image is embedded.
        text = clip.tokenize(d[img_path][0]).to(device)
        with Image.open(img_path) as img:
            image = preprocess(img).unsqueeze(0).to(device)
        # Stored on the CPU so the pickles can be loaded without a GPU.
        embeddings_dictionary_images[img_path] = model.encode_image(image).cpu()
        embeddings_dictionary_text[img_path] = model.encode_text(text).cpu()

# File names must match what the evaluation cell reads.
with open('text_embeddings.pkl', 'wb') as file:
    pickle.dump(embeddings_dictionary_text, file)

with open('image_embeddings.pkl', 'wb') as file:
    pickle.dump(embeddings_dictionary_images, file)

The embeddings are now created and saved to pickle files. They can be loaded to calculate Recall@1, Recall@5 and Recall@10 for both image-to-text and text-to-image retrieval.

In [None]:
import pickle

# Load the embeddings written by the previous cell. These files are produced
# locally by this notebook; never unpickle files from untrusted sources.
with open('image_embeddings.pkl', 'rb') as file:
    image_dictionary = pickle.load(file)

with open('text_embeddings.pkl', 'rb') as file:
    text_dictionary = pickle.load(file)

# Look up both dictionaries with the same key sequence so that entry i of the
# image list and entry i of the text list belong to the same image.
image_features1 = [image_dictionary[key] for key in image_dictionary.keys()]
text_features1 = [text_dictionary[key] for key in image_dictionary.keys()]

# Each stored embedding has a leading batch row of one; [0] drops it.
image_features = [emb[0].cpu().numpy() for emb in image_features1]
text_features = [emb[0].cpu().numpy() for emb in text_features1]
      
from sklearn.metrics.pairwise import cosine_similarity

def calculate_similarity(text_features, image_features):
    """Cosine similarity of every text embedding against every image embedding.

    Returns a matrix with one row per text and one column per image.
    """
    return cosine_similarity(text_features, image_features)


In [None]:
import numpy as np

def rank_similarities(similarity_matrix):
    """For every row, return column indices ordered from most to least similar."""
    negated = -similarity_matrix  # argsort is ascending, so negate to put high scores first
    return np.argsort(negated, axis=1)

In [None]:
def recall_at_k(ranked_indices, ground_truth_indices, k):
    """Mean Recall@k over all queries.

    Args:
        ranked_indices: 2-D array; row i lists candidate indices for query i,
            best first.
        ground_truth_indices: sequence whose i-th entry lists the relevant
            candidate indices for query i.
        k: cutoff; only the first ``k`` ranked candidates count.

    Returns:
        Mean over queries of |top-k ∩ relevant| / |relevant| (duplicates in a
        ground-truth entry are counted once). Queries with no relevant items
        are skipped because their recall is undefined. If no query has any
        relevant item, the result is NaN.
    """
    recalls = []
    for i, ground_truth in enumerate(ground_truth_indices):
        relevant = np.unique(ground_truth)
        if relevant.size == 0:
            continue
        top_k = ranked_indices[i, :k]
        recalls.append(np.intersect1d(top_k, relevant).size / relevant.size)
    return float(np.mean(recalls)) if recalls else float("nan")

def calculate_recall(similarity_matrix, ground_truth_indices_text_to_image, ground_truth_indices_image_to_text, k):
    """Recall@k in both retrieval directions.

    Rows of ``similarity_matrix`` are queries for text-to-image retrieval;
    ranking its transpose gives image-to-text retrieval.

    Returns:
        (recall_text_to_image, recall_image_to_text)
    """
    directions = (
        (similarity_matrix, ground_truth_indices_text_to_image),
        (similarity_matrix.T, ground_truth_indices_image_to_text),
    )
    recall_text_to_image, recall_image_to_text = (
        recall_at_k(rank_similarities(matrix), truth, k) for matrix, truth in directions
    )
    return recall_text_to_image, recall_image_to_text


In [None]:
# Diagonal ground truth: text i belongs to image i only (both feature lists
# were built from the same key sequence).
l = [[idx] for idx in range(len(image_features))]

In [None]:

similarity_matrix = calculate_similarity(text_features, image_features)

# Report Recall@1, @5 and @10 in both directions (the loop ends with k = 10).
for k in (1, 5, 10):
    recall_text_to_image, recall_image_to_text = calculate_recall(similarity_matrix, l, l, k)
    print(f"Recall@{k} for Text-to-Image: {recall_text_to_image}")
    print(f"Recall@{k} for Image-to-Text: {recall_image_to_text}")