In [1]:
import faiss
import os
import numpy as np
import pandas as pd

import torch
from torch import Tensor
from torchvision import models

from torchvision.transforms import Compose, transforms
from PIL import Image
import cv2

In [33]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise stay on the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

cuda:0


### Loading Model and getting inference

In [34]:
# ResNet-50 is used here purely as a frozen feature extractor.
model = models.resnet50(pretrained=True, progress=False)
for param in model.parameters():
    param.requires_grad = False
# Swapping the classifier head for Identity makes the forward pass return the
# 2048-d pooled features instead of class logits.
model.fc = torch.nn.Identity()
# Module.to() and Module.eval() both return the module itself, so assigning the
# chained result keeps the cell from echoing the model repr without a dummy print.
model = model.to(device).eval()



In [6]:
def as_numpy(val: Tensor) -> np.ndarray:
    """Return the tensor's values as a NumPy array.

    The tensor is first detached from the autograd graph and copied to host
    memory, so it may live on any device and may require gradients.
    """
    host_tensor = val.detach().cpu()
    return host_tensor.numpy()

In [7]:
def transform(images: np.ndarray):
    """Turn one image array into a single-item batch tensor.

    Despite the plural name, `images` is a single 3-axis array (the notebook
    passes the result of cv2.imread). Its last axis is reversed before the
    array is wrapped in a PIL image, converted with ToTensor, and given a
    leading batch axis.
    """
    # Reverse the channel axis (cv2.imread output is BGR; PIL expects RGB).
    pil_image = Image.fromarray(images[:, :, ::-1])
    pipeline = Compose([transforms.ToTensor()])
    tensor = pipeline(pil_image)
    return tensor.unsqueeze(0)

In [8]:
# Folder holding the query images (absolute, machine-specific path). The trailing
# slash matters: the query path is built from it by plain string concatenation.
IMAGE_HOME = "/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_Cat_flowing_down_a_sofa/"

In [9]:
# File names (not full paths) of every entry in IMAGE_HOME ending in ".jpg".
img_list = [name for name in os.listdir(IMAGE_HOME) if name.endswith(".jpg")]

In [10]:
# Peek at the first two collected file names.
img_list[:2]

['g054_dgzi4ym.jpg', 'g054_dgzoaw3.jpg']

In [11]:
IMAGE_PATH = os.path.join(IMAGE_HOME, "g054_dgzg3g2.jpg")
img = cv2.imread(IMAGE_PATH)
# cv2.imread reports failure by returning None instead of raising, which would
# otherwise surface as an obscure TypeError inside transform().
if img is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
# transform() yields a (1, 3, height, width) float tensor with values in [0, 1].
imgt = transform(img).to(device)
with torch.no_grad():
    # Slice + unsqueeze keeps an explicit single-image batch for the forward pass.
    inference = as_numpy(model(torch.unsqueeze(imgt[0], 0)))

In [39]:
print(inference.shape, inference.dtype)
# ndarray.reshape returns a new array instead of modifying in place, so the
# result has to be assigned for the call to have any effect.
inference = inference.reshape(1, -1)
inference.shape

(1, 2048) float32


(1, 2048)

### Getting Faiss scan result

In [14]:
# Load a FAISS index that was previously written to disk (absolute, machine-specific
# path; the file name suggests an HNSW index — confirm against whatever built it).
INDEX_PATH =  "/nethome/kravicha3/.eva/0.1.5+dev/index/HNSW_dataindex.index"
index = faiss.read_index(INDEX_PATH)

In [15]:
# Number of nearest neighbours to retrieve for the query vector.
k = 10
# D holds the distances and I the matching ids, one row per query (a single row here).
D, I = index.search(inference, k)

In [16]:
# Going through np.asarray keeps this cell safe to re-run: after the first run
# D and I are already plain lists, which have no .tolist() method.
D = np.asarray(D).tolist()
I = np.asarray(I).tolist()
D,I

([[0.0,
   71.97640228271484,
   81.04901885986328,
   81.88858795166016,
   83.09184265136719,
   83.36123657226562,
   84.52255249023438,
   87.03155517578125,
   87.73091125488281,
   89.1760025024414]],
 [[9079, 8710, 7866, 2058, 2057, 7965, 2207, 4820, 4627, 2132]])

In [17]:
# Print each hit as an (id, distance) pair, in the order the search returned them.
for neighbour in zip(I[0], D[0]):
    print(neighbour)

(9079, 0.0)
(8710, 71.97640228271484)
(7866, 81.04901885986328)
(2058, 81.88858795166016)
(2057, 83.09184265136719)
(7965, 83.36123657226562)
(2207, 84.52255249023438)
(4820, 87.03155517578125)
(4627, 87.73091125488281)
(2132, 89.1760025024414)


## Checking results using sqlite

In [18]:
import sqlite3
# Open the catalog database so FAISS ids can be mapped back to file paths.
# NOTE(review): this path is under /home/... while INDEX_PATH is under /nethome/... —
# confirm both point at the same installation. sqlite3.connect creates an empty
# database when the file is missing, so a wrong path only fails later with
# "no such table".
con = sqlite3.connect("/home/kravicha3/.eva/0.1.5+dev/eva_catalog.db")

In [19]:
# Single cursor reused by the query cells that follow.
c = con.cursor()

In [24]:
# Preview the first five rows of the table to see its layout.
# Cursor.execute returns the cursor, so the fetch can be chained.
r = c.execute("SELECT * FROM '192111ccbbbfc5042415841dfaa9f90a' LIMIT 5;").fetchall()
for row in r:
    print(row, end="\n\n")

(1, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_This_cat_plotting_to_kill_someone/g1327_czcqbl6.jpg')

(2, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_This_cat_plotting_to_kill_someone/g1327_czcu1y7.jpg')

(3, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_This_cat_plotting_to_kill_someone/g1327_czd2m0n.jpg')

(4, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_This_cat_plotting_to_kill_someone/g1327_czcrc83.png')

(5, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_This_cat_plotting_to_kill_someone/g1327_czd40us.jpg')



In [26]:
# Resolve every returned neighbour id to its catalog row.
for i in I[0]:
    # Bind the id as a query parameter instead of formatting it into the SQL text.
    # int() also covers the case where I still holds NumPy integers, which
    # sqlite3 cannot bind directly.
    c.execute(
        "SELECT * FROM '192111ccbbbfc5042415841dfaa9f90a' WHERE _row_id=?",
        (int(i),),
    )
    r = c.fetchall()
    print(r)

[(9079, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_Cat_flowing_down_a_sofa/g054_dgzg3g2.jpg')]
[(8710, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_a_frog_riding_a_beetle/g382_d13crsr.jpg')]
[(7866, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_Alexis_Ohanian_(CEO_and_founder_of_reddit)_holding_a_sign/g1333_cnorg0z.jpg')]
[(2058, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_Vladimir_Putin_in_a_submarine_in_the_Black_Sea/g1191_cu7aerv.jpg')]
[(2057, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_Vladimir_Putin_in_a_submarine_in_the_Black_Sea/g1191_cu794hl.jpg')]
[(7965, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_Hawk_Owl_flying,_looking_into_camera/g1294_cvlk0dk.jpg')]
[(2207, '/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_This_man_dancing_at_a_wedding_(x-post_from__r_pics)/g097_dc

## Changing FAISS to IVF + PQ (IndexIVFPQ; no OPQ transform is applied yet)

In [125]:
d = 2048          # feature dimension (matches the (1, 2048) model output)
code_size = 32    # bytes per stored code; passed as the number of PQ sub-quantizers
ncentroids = 512  # number of inverted lists (coarse k-means centroids)
nbits = 8         # bits per sub-quantizer, so one code takes code_size * nbits / 8 bytes

coarse_quantizer = faiss.IndexFlatL2(d)
# Use the named constants instead of repeating the literals 512 and 8.
index = faiss.IndexIVFPQ(coarse_quantizer, d,
                         ncentroids, code_size, nbits)
# Number of inverted lists visited per query.
index.nprobe = 5

In [129]:
def _embed_image(img: np.ndarray) -> np.ndarray:
    """Run one image (as returned by cv2.imread) through the notebook-level
    `transform` and `model`, returning `as_numpy(model(...))` for that single-image batch.
    """
    imgt = transform(img).to(device)
    with torch.no_grad():
        return as_numpy(model(torch.unsqueeze(imgt[0], 0)))


def run_indexing(start_path = '.', min_files: int = 512):
    """Walk `start_path` and stack one model-output row per readable image.

    Symlinks and files that cv2.imread cannot decode are skipped, including
    when such a file happens to be the first one encountered.

    Args:
        start_path: Root directory handed to os.walk.
        min_files: Walking stops after the first directory that brings the
            number of embedded images above this value. The check runs once
            per directory rather than per file, so the returned count can
            exceed it.

    Returns:
        (features, count): `features` is the per-image model outputs
        concatenated along axis 0, or None when no image could be read;
        `count` is the number of images embedded.
    """
    features = []
    for dirpath, _dirnames, filenames in os.walk(start_path):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            if os.path.islink(fp):
                continue
            img = cv2.imread(fp)
            if img is None:
                # Not a decodable image; skip it instead of crashing in transform().
                continue
            features.append(_embed_image(img))
        if len(features) > min_files:
            break

    if not features:
        return None, 0
    # One concatenate at the end avoids re-copying the whole array for every image.
    return np.concatenate(features, axis=0), len(features)

In [130]:
# Embed a sample of the dataset to train the new index on.
# NOTE(review): this is a relative path, so the result depends on the notebook's
# working directory, unlike the absolute IMAGE_HOME used earlier.
array, num = run_indexing('../../dataset/Reddit_Provenance_Datasets/data/')
print(num)

555


In [131]:
# Expect (num, d): one feature row per embedded image.
array.shape

(555, 2048)

In [132]:
# A freshly built IndexIVFPQ reports False here until train() has been called.
print(index.is_trained)

False


In [133]:
# Train the index on the sampled feature rows.
# NOTE(review): the sample has only slightly more rows than ncentroids; FAISS
# normally wants many training points per centroid — confirm this is enough.
index.train(array)



In [134]:
# Only the sampled rows returned by run_indexing are added, not the whole dataset.
index.add(array)