## Import modules

In [1]:
import os
import torch
import pandas as pd
from PIL import Image
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from random import random, randrange
from torchvision import transforms,models
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import clip
from copy import deepcopy

## Load the CLIP model

In [2]:
# Pick the best available device for reference.
# NOTE(review): `device` is not used below. The CLIP model is loaded on the CPU,
# which matches the embedding cells: they never move their input tensors to another device.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model, preprocess = clip.load("ViT-B/32", device="cpu")

## Creating image lists

In [67]:
# NOTE: despite its name, CLEAN_IMAGES_CSV_PATH is a directory of reference
# images, not a CSV file. The name is kept so other cells keep working.
CLEAN_IMAGES_CSV_PATH = os.path.join("data", "DAM")
# os.listdir returns entries in arbitrary order. Sorting makes the list order
# reproducible across machines, and with it the row indices of the similarity search.
reference_images = [
    os.path.join(CLEAN_IMAGES_CSV_PATH, name)
    for name in sorted(os.listdir(CLEAN_IMAGES_CSV_PATH))
]

# Client (query) images to be matched against the reference images.
TEST_IMAGES_PATH = os.path.join("data", "test_image_headmind")
img_list = [
    os.path.join(TEST_IMAGES_PATH, name)
    for name in sorted(os.listdir(TEST_IMAGES_PATH))
]

## Creating the answers dictionary

In [68]:
# Ground-truth answers are spread over five files. The file names suggest one
# file per batch of 16 client images. The order below matches the original
# concatenation order.
ANSWER_CSV_FILES = [
    "answers1-16.csv",
    "results17-32.csv",
    "results33-48.csv",
    "results49-64.csv",
    "answer65-80.csv",
]
# A multi-character separator can only be handled by the python parser.
# Requesting it explicitly avoids pandas' ParserWarning about falling back from the C engine.
answer_frames = [
    pd.read_csv(path, header=None, sep=", ", engine="python")
    for path in ANSWER_CSV_FILES
]
csv_arthur = answer_frames[-1]
# ignore_index=True yields one unique 0..n-1 index instead of repeating each file's own index.
csv_all = pd.concat(answer_frames, ignore_index=True)
# Column '0' holds reference codes and column '1' holds client image ids (see the printed frame below).
csv_all.columns = ['0', '1']


In [69]:
# Bare last expression: Jupyter renders the DataFrame with its rich (HTML) repr instead of plain text.
csv_all

                  0                               1
0     M0538OCALM35R  image-20210928-102713-12d2869d
1     M0505SLOIM989  image-20210928-102718-2474636a
2     M0565ONGEM50P  image-20210928-102721-8eaea48f
3     M9203UWDIM59E  image-20210928-102725-7e28b44c
4     M9204UMOAM918  image-20210928-102725-7e28b44c
..              ...                             ...
12  SOSTELAIR1SZJ86         MicrosoftTeams-image_46
13  SOSTELAIR1807YB         MicrosoftTeams-image_48
14    DSGTS6UXR10A0         MicrosoftTeams-image_50
15    S5652CBAAM41G         MicrosoftTeams-image_53
16    S5652CCEHM900         MicrosoftTeams-image_54

[96 rows x 2 columns]


In [80]:
# Group the ground-truth rows by client image id (column '1'). Each id maps to
# the list of reference codes (column '0') expected for it, in file order.
# Some ids appear on several rows.
answers = {}
for image_id, reference_code in zip(csv_all['1'], csv_all['0']):
    answers.setdefault(image_id, []).append(reference_code)

## Creating embeddings for the reference images

In [71]:
# CLIP image embeddings for the reference images, in the same order as
# `reference_images`, so index i of this list maps back to reference_images[i].
reference_features = []
with torch.no_grad():  # inference only: skip autograd bookkeeping for the whole loop
    for img in tqdm(reference_images):
        # The context manager closes the file handle even if preprocessing raises.
        with Image.open(img) as pil_image:
            image = preprocess(pil_image).unsqueeze(0)  # add a leading batch dimension
        reference_features.append(model.encode_image(image).flatten())

100%|██████████| 2766/2766 [05:05<00:00,  9.06it/s]


## Creating embeddings for the client images

In [72]:
# CLIP image embeddings for the client (query) images, in the same order as
# `img_list`, so index i of this list maps back to img_list[i].
client_features = []
with torch.no_grad():  # inference only: skip autograd bookkeeping for the whole loop
    for img in tqdm(img_list):
        # The context manager closes the file handle even if preprocessing raises.
        with Image.open(img) as pil_image:
            image = preprocess(pil_image).unsqueeze(0)  # add a leading batch dimension
        client_features.append(model.encode_image(image).flatten())

100%|██████████| 80/80 [00:19<00:00,  4.13it/s]


## Retrieving the 10 most similar reference images for each client image

In [77]:
# Number of nearest reference images kept per client image.
TOP_K = 10

# Each stored feature was flattened to 1-D, so stacking yields 2-D (n_images, dim) tensors.
client_embedding_tensor = torch.stack(client_features)
embedding_tensor = torch.stack(reference_features)

# L2-normalise every row so the matrix product below is the cosine similarity.
client_embedding_tensor_normalized = client_embedding_tensor / client_embedding_tensor.norm(dim=1, keepdim=True)
embedding_tensor_normalized = embedding_tensor / embedding_tensor.norm(dim=1, keepdim=True)

# cosine_similarities[i, j] = similarity of client image i to reference image j.
cosine_similarities = torch.mm(client_embedding_tensor_normalized, embedding_tensor_normalized.t())

# topk avoids fully sorting every row. k is clamped because topk raises when k
# exceeds the number of references; slicing an argsort silently truncated instead.
top_k_effective = min(TOP_K, cosine_similarities.shape[1])
closest_indices = torch.topk(cosine_similarities, k=top_k_effective, dim=1).indices  # best match first
closest_indices

tensor([[2017, 1772, 2115, 2074, 2163, 2154, 2109, 2110, 1681, 1658],
        [1896, 2163, 1735, 1895, 1873, 1654, 2205, 1743, 2089, 2464],
        [2163, 2085, 2430, 2763, 1638, 1978, 1613, 2190, 1800, 2646],
        [2280, 2263, 2243, 2269, 2430, 2676, 2270, 2583, 2582, 2255],
        [2676, 2243, 2263, 2247, 2245, 2421, 2680, 2660, 2583, 2335],
        [1492, 1426, 1435, 1500, 1510, 1467, 1472, 1433, 1327, 1431],
        [1864, 2536, 2267, 2236, 2163, 2190, 1982, 2464, 1571, 2280],
        [1735, 1654, 2464, 2628, 1659, 2086, 1641, 2040, 1645, 2523],
        [2628, 2644, 1583, 1735, 2646, 2280, 1353, 2684, 2685, 2625],
        [1934, 1828, 1827, 1926, 2625, 2644, 2289, 2290, 1549, 1958],
        [ 659,  660, 1076,  854, 1169, 2678, 1558, 1571, 1548, 1581],
        [ 259, 1302, 2277, 1571, 1546, 2280, 2628, 2678, 2278, 1195],
        [1036, 1034, 1024, 1355, 1035,  259, 1354,  665, 2439, 1022],
        [1354, 1022, 1024, 1036, 1023, 2664, 1502, 1027, 1355, 2489],
        [1717, 2076,

In [39]:
# Inspect the ranked reference indices retrieved for the first client image.
closest_indices[0, :]

tensor([1660, 1663, 1666, 1780, 1756])

In [78]:
# Map each client image id (its file name without the extension) to the paths
# of its retrieved reference images, in the order given by `closest_indices`.
guesses = {}

for i, img in enumerate(img_list):
    # os.path.basename handles the platform's separators. Splitting on a literal
    # backslash only worked on Windows; elsewhere it left the directory in the key,
    # so no id matched `answers`.
    file_name = os.path.basename(img)
    img_id = os.path.splitext(file_name)[0]
    guesses[img_id] = [reference_images[ind] for ind in closest_indices[i].tolist()]

In [75]:
# Preview a few entries instead of dumping every client image's guesses.
dict(list(guesses.items())[:5])

{'image-20210928-102713-12d2869d': ['data\\DAM\\M500SJAAUXM334.jpeg',
  'data\\DAM\\M0538OCALM35R.jpeg',
  'data\\DAM\\M531SJAUGXM334.jpeg',
  'data\\DAM\\M505SOAAUXM43R.jpeg',
  'data\\DAM\\M565SOAAUXM830.jpeg'],
 'image-20210928-102718-2474636a': ['data\\DAM\\M0566PAWAXM90B.jpeg',
  'data\\DAM\\M565SOAAUXM830.jpeg',
  'data\\DAM\\M0531NWDDM900.jpeg',
  'data\\DAM\\M0566PAWAXM85B.jpeg',
  'data\\DAM\\M0566JAWAXM85B.jpeg'],
 'image-20210928-102721-8eaea48f': ['data\\DAM\\M565SOAAUXM830.jpeg',
  'data\\DAM\\M505SOLAAXM79B.jpeg',
  'data\\DAM\\S0110ONMJM43F.jpeg',
  'data\\DAM\\VRB44560N0.jpeg',
  'data\\DAM\\M0505JAWAXM41G.jpeg'],
 'image-20210928-102725-7e28b44c': ['data\\DAM\\M9204ULERXM423.jpeg',
  'data\\DAM\\M9203UMOSM46E.jpeg',
  'data\\DAM\\M9203CMOSM46E.jpeg',
  'data\\DAM\\M9203UWDIM59E.jpeg',
  'data\\DAM\\S0110ONMJM43F.jpeg'],
 'image-20210928-102729-f53d9faf': ['data\\DAM\\S5652CWGSM919.jpeg',
  'data\\DAM\\M9203CMOSM46E.jpeg',
  'data\\DAM\\M9203UMOSM46E.jpeg',
  'data\\DAM

In [82]:
# Top-k hit rate (%). A client image counts as a hit when any of its
# ground-truth reference codes occurs as a substring of any of its retrieved
# reference paths. Only ids present in both `answers` and `guesses` are counted.
accuracy = 0  # number of hits
nb_guess = 0  # number of client images evaluated
for answer_key, expected_codes in answers.items():
    if answer_key not in guesses:
        continue
    nb_guess += 1
    if any(code in guess_path for guess_path in guesses[answer_key] for code in expected_codes):
        accuracy += 1

# If no id matched, the keys were parsed inconsistently (e.g. path separators).
# Fail with an explicit message rather than a ZeroDivisionError.
assert nb_guess > 0, "no client image id from `answers` was found in `guesses`"
accuracy / nb_guess * 100

57.14285714285714