# Assginment02: CLIP application
In this application, we will use CLIP to assess image similiarity

### Basic Imports

In [1]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

from transformers import AutoProcessor, CLIPModel, AutoImageProcessor, AutoModel
import faiss

  from .autonotebook import tqdm as notebook_tqdm


### Device

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Model

In [3]:
#Load CLIP model and processor
processor_clip = AutoProcessor.from_pretrained("/data/lab/Qi Zimo/STA303-Assignment02/models/clip_processor")
model_clip = CLIPModel.from_pretrained("/data/lab/Qi Zimo/STA303-Assignment02/models/clip_model").to(device)
#Load DINOv2 model and processor
processor_dino = AutoImageProcessor.from_pretrained('/data/lab/Qi Zimo/STA303-Assignment02/models/dino_processor')
model_dino = AutoModel.from_pretrained('/data/lab/Qi Zimo/STA303-Assignment02/models/dino_model').to(device)

In [4]:
from tqdm import tqdm
#Retrieve all filenames
images = []
for dirpath, dirnames, filenames in os.walk(r'/data/lab/Qi Zimo/data/val2017/'):  
    for filename in filenames:  
        if filename.endswith('.jpg'):  
            images.append(os.path.join(dirpath, filename)) 

### Embedding

In [5]:
#normalizes embeddings and add to the index
def add_vector_to_index(embedding, index):
    vector = embedding.detach().cpu().numpy()
    vector = np.float32(vector)
    faiss.normalize_L2(vector)
    index.add(vector)

def extract_features_clip(image):
    
        return image_features

def extract_features_dino(image):
    
        return image_features.mean(dim=1)

#Create 2 indexes.
index_clip = faiss.IndexFlatL2(512)
index_dino = faiss.IndexFlatL2(768)

#extract features and store features in indexes
for image_path in tqdm(images, desc="Processing Images", unit="image"):
    img = Image.open(image_path).convert('RGB')
    with torch.no_grad():
        inputs = processor_clip(images=img, return_tensors="pt").to(device)
        clip_features = model_clip.get_image_features(**inputs)
    add_vector_to_index(clip_features,index_clip)

    with torch.no_grad():
        inputs = processor_dino(images=img, return_tensors="pt").to(device)
        outputs = model_dino(**inputs)
        dino_features = outputs.last_hidden_state
    add_vector_to_index(dino_features.mean(dim=1),index_dino)

#store the indexes locally
faiss.write_index(index_clip,"clip.index")
faiss.write_index(index_dino,"dino.index")

Processing Images: 100%|██████████| 5003/5003 [08:42<00:00,  9.58image/s]


### Search for best similarity

In [5]:
def normalizeL2(embeddings):
    vector = embeddings.detach().cpu().numpy()
    vector = np.float32(vector)
    faiss.normalize_L2(vector)
    return vector

#read the indexes
index_clip = faiss.read_index("clip.index")
index_dino = faiss.read_index("dino.index")
same_num = 0
diff_num = 0

dino_for_clip = []
clip_for_dino = []
dino_best = []
clip_best = []

for image_path in tqdm(images, desc="Assessing Images", unit="image"):
    img = Image.open(image_path).convert('RGB')
    with torch.no_grad():
        inputs = processor_clip(images=img, return_tensors="pt").to(device)
        clip_features = model_clip.get_image_features(**inputs)

    with torch.no_grad():
        inputs = processor_dino(images=img, return_tensors="pt").to(device)
        outputs = model_dino(**inputs)
        dino_features = outputs.last_hidden_state
        dino_features = dino_features.mean(dim=1)
    
    dino_features = normalizeL2(dino_features)
    clip_features = normalizeL2(clip_features)
    
    d_dino,i_dino = index_dino.search(dino_features,3)
    d_clip,i_clip = index_clip.search(clip_features,3)
    
    if (i_dino[0][1] != i_clip[0][1]):
        same_num = same_num +1
    elif i_dino[0][1] == i_clip[0][1]:
        diff_num = diff_num +1
        
    dino_image_id = i_dino[0][0:3]
    clip_image_id = i_clip[0][0:3]

    vec_clip_ori = index_clip.reconstruct(int(clip_image_id[0]))
    vec_dino_ori = index_dino.reconstruct(int(dino_image_id[0]))

    clip_vecs = []
    dino_vecs = []
    for i in range(3):
        clip_vecs.append(np.dot(index_clip.reconstruct(int(clip_image_id[i])),vec_clip_ori))
        dino_vecs.append(np.dot(index_dino.reconstruct(int(dino_image_id[i])),vec_dino_ori))
        
    dino_best.append(dino_vecs[1])
    clip_best.append(clip_vecs[1])
    dino_for_clip.append(np.dot(index_dino.reconstruct(int(clip_image_id[1])),index_dino.reconstruct(int(clip_image_id[0]))))
    clip_for_dino.append(np.dot(index_clip.reconstruct(int(dino_image_id[1])),index_clip.reconstruct(int(dino_image_id[0]))))
    
    image_dino = []
    image_clip = []
    for id in dino_image_id:
        image_dino.append(Image.open(images[id]))
    for id in clip_image_id:
        image_clip.append(Image.open(images[id]))
    
    fig, axs = plt.subplots(2, 4, figsize=(10, 5))  
    axs[0,0].imshow(img)  
    axs[0,0].set_title('image')  
    axs[0,0].axis('off')
    for i in range(3):
        axs[0,i+1].imshow(image_dino[i])  
        axs[0,i+1].set_title('DINO target'+ str(dino_vecs[i]),fontsize=8) 
        axs[0,i+1].axis('off') 

    axs[1,0].imshow(img)  
    axs[1,0].set_title('image')  
    axs[1,0].axis('off')
    for i in range(3):
        axs[1,i+1].imshow(image_clip[i])  
        axs[1,i+1].set_title('CLIP target'+ str(clip_vecs[i]),fontsize=8)
        axs[1,i+1].axis('off') 
            
    plt.savefig('results/'+ str(dino_image_id[0])+'result.jpg')
    plt.close()

Assessing Images: 100%|██████████| 5007/5007 [49:08<00:00,  1.70image/s]  


In [6]:
clip_dec_by_dino = [a-b for a,b in zip(dino_best,dino_for_clip)]
dino_dec_by_clip = [a-b for a,b in zip(clip_best,clip_for_dino)]
clip_dec_by_dino_mean = sum(clip_dec_by_dino)/len(clip_dec_by_dino)
dino_dec_by_clip_mean = sum(dino_dec_by_clip)/len(dino_dec_by_clip)

dino_mean = sum(dino_best) / len(dino_best)
clip_mean = sum(clip_best) / len(clip_best)
same_rate = same_num/len(images)
diff_rate = diff_num/len(images)
# dino_for_clip.count(0)
print("clip_dec_by_dino_mean",clip_dec_by_dino_mean)
print("dino_dec_by_clip_mean",dino_dec_by_clip_mean)
print("dino_mean",dino_mean)
print("clip_mean",clip_mean)
print("same_rate",same_rate)
print("diff_rate",diff_rate)

clip_dec_by_dino_mean 0.11313992072977848
dino_dec_by_clip_mean 0.07060249328065685
dino_mean 0.7181972342561719
clip_mean 0.8145475128000788
same_rate 0.8480127821050529
diff_rate 0.15198721789494707
