In [12]:
import faiss
import os
import numpy as np
import pandas as pd

import torch
from torch import Tensor
from torchvision import models

from torchvision.transforms import Compose, transforms
from PIL import Image
import cv2

### Loading the model and running inference

In [62]:
# Build an ImageNet-pretrained ResNet-50 to use as a fixed feature extractor.
model = models.resnet50(pretrained=True, progress=False)
# Freeze every parameter; no gradients are needed for inference.
model.requires_grad_(False)
# Replace the classification head with a pass-through so the network outputs
# the pooled features instead of class logits.
model.fc = torch.nn.Identity()
# Switch to evaluation mode (last expression, so the module repr is displayed).
model.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [63]:
def as_numpy(val: Tensor) -> np.ndarray:
    """Convert a tensor to a NumPy array.

    The tensor is first detached from the autograd graph and moved to the
    CPU, then converted.

    Args:
        val: Tensor to convert.

    Returns:
        The tensor's data as a NumPy array.
    """
    detached = val.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

In [64]:
def transform(images: np.ndarray):
    """Turn a single image array into a batched tensor.

    The last axis of the array is reversed before conversion (for an array
    returned by ``cv2.imread`` this swaps BGR to RGB), the result is wrapped
    in a PIL image, converted with ``ToTensor`` and given a leading batch
    dimension of size 1.

    Args:
        images: A single image as a 3-D array; despite the plural name, one
            image is processed per call.

    Returns:
        The converted image tensor with an extra leading dimension.
    """
    # Reverse the channel axis, then hand the array to PIL.
    pil_image = Image.fromarray(images[:, :, ::-1])
    pipeline = Compose([transforms.ToTensor()])
    tensor = pipeline(pil_image)
    return tensor.unsqueeze(0)

In [65]:
# Directory containing the .jpg images used in the cells below.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
IMAGE_HOME = "/nethome/kravicha3/aryan/project/dataset/Reddit_Provenance_Datasets/data/_Cat_flowing_down_a_sofa/"

In [66]:
# Collect the names of all .jpg files found directly inside IMAGE_HOME.
img_list = [file for file in os.listdir(IMAGE_HOME) if file.endswith(".jpg")]

In [67]:
# Display the collected file names (os.listdir returns entries in arbitrary order).
img_list

['g054_dgzi4ym.jpg',
 'g054_dgzoaw3.jpg',
 'g054_dgzevtg.jpg',
 'g054_dgzarub.jpg',
 'g054_dgzhqrr.jpg',
 'g054_dgzfrvk.jpg',
 'g054_dgzb6kd.jpg',
 'g054_dgzdnl4.jpg',
 'g054_dgzhtgu.jpg',
 'g054_dh0b7ud.jpg',
 'g054_dgzg97e.jpg',
 'g054_dgzgtd7.jpg',
 'g054_dgzaxjz.jpg',
 'g054_dgzikzk.jpg',
 'g054_dgzc1sv.jpg',
 'g054_dgzd6s8.jpg',
 'g054_dgzfexv.jpg',
 'g054_dgzep67.jpg',
 'g054_dgzg451.jpg',
 'g054_dgzh9ps.jpg',
 'g054_root.jpg',
 'g054_dgzg3g2.jpg',
 'g054_dh0ahk3.jpg',
 'g054_dgzd0fh.jpg',
 'g054_dgzb67i.jpg']

In [68]:
# Pick one image from the directory as the query image.
IMAGE_PATH = os.path.join(IMAGE_HOME, "g054_dgzg3g2.jpg")
img = cv2.imread(IMAGE_PATH)
# cv2.imread returns None instead of raising when the file is missing or
# cannot be decoded; fail here with a clear message rather than later.
if img is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")

# Convert the loaded array into a batched tensor for the model.
imgt = transform(img)

In [69]:
# Model input: a batch of one image; each image is (3, height, width) with
# values in the range 0-1 (as produced by ToTensor).
with torch.no_grad():
    # Take the first image and restore the leading batch dimension.
    features = model(imgt[0].unsqueeze(0))
    inference = as_numpy(features)

In [70]:
# Show the shape and dtype of the raw embedding.
print(inference.shape, inference.dtype)
# ndarray.reshape returns a new array and does not modify in place, so the
# result has to be assigned for the reshape to take effect. This guarantees a
# 2-D (1, n_features) array for the index search below.
inference = inference.reshape(1, -1)
inference.shape

(1, 2048) float32


(1, 2048)

### Querying the FAISS index

In [71]:
# Location of a serialized FAISS index on disk.
# NOTE(review): absolute, machine-specific path; the file name suggests an
# HNSW index — confirm how and from which data it was built.
INDEX_PATH =  "/nethome/kravicha3/.eva/0.1.5+dev/index/HNSW_dataindex.index"
# Load the index from disk.
index = faiss.read_index(INDEX_PATH)

In [79]:
# Number of nearest neighbours to retrieve for the query vector.
k = 10
# D receives the distances and I the ids of the k nearest entries, one row
# per query row.
# NOTE(review): the distance metric depends on how the index was built — confirm.
D, I = index.search(inference, k)

In [80]:
# Convert the search results to plain nested Python lists.
# np.asarray makes the cell safe to re-run: once D and I are already lists,
# calling .tolist() on them directly would raise AttributeError.
D = np.asarray(D).tolist()
I = np.asarray(I).tolist()
D, I

([[0.0,
   71.97640228271484,
   81.04901885986328,
   81.88858795166016,
   83.09184265136719,
   83.36123657226562,
   84.52255249023438,
   87.03155517578125,
   87.73091125488281,
   89.1760025024414]],
 [[9079, 8710, 7866, 2058, 2057, 7965, 2207, 4820, 4627, 2132]])

In [83]:
# Print each (distance, id) pair for the single query row.
# The first argument to zip was missing (a SyntaxError); pairing D[0] with
# I[0] reproduces the recorded output.
for pair in zip(D[0], I[0]):
    print(pair)

(0.0, 9079)
(71.97640228271484, 8710)
(81.04901885986328, 7866)
(81.88858795166016, 2058)
(83.09184265136719, 2057)
(83.36123657226562, 7965)
(84.52255249023438, 2207)
(87.03155517578125, 4820)
(87.73091125488281, 4627)
(89.1760025024414, 2132)
