In [1]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

In [2]:
# List the photo library. os.listdir returns entries in arbitrary order, so the
# listing is sorted to keep the saved names/vectors in a reproducible order.
images = sorted(os.listdir("./Pictures/Camera Roll"))

In [3]:
# Load an ImageNet-pretrained ResNet-50 to use as a feature extractor.
# `pretrained=True` is deprecated in newer torchvision releases in favour of the
# `weights=` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True` refers to.
# Older torchvision versions without the enum fall back to the legacy flag.
_resnet50_weights = getattr(torchvision.models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = torchvision.models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = torchvision.models.resnet50(pretrained=True)



In [4]:
# File names of the images that were embedded successfully (parallel to all_vecs).
all_names = []
# None marks "no vectors collected yet"; an empty list cannot be vstack-ed with
# a feature row, so it must not be used as the initial value.
all_vecs = None
# Switch to inference mode so batch-norm layers use their running statistics.
model.eval()
# Trailing separator so that plain concatenation `root + file` yields a valid path.
root = "./Pictures/Camera Roll/"

In [5]:
# Deterministic preprocessing for feature extraction (this is not augmentation):
# random flips are deliberately absent so that the same image always maps to the
# same embedding, and the normalisation uses the ImageNet channel statistics the
# pretrained torchvision weights were trained with.
transform = transforms.Compose(
    [
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [6]:
# Maps a layer name to the most recent output captured by that layer's forward
# hook. It must be a dict: a tuple does not support item assignment.
activation = {}


def get_activaton(name):
    """Build a forward hook that records a layer's output under ``name``.

    The (misspelled) function name is kept because other cells call it.

    Args:
        name: Key under which the captured output is stored in ``activation``.

    Returns:
        A ``hook(model, input, output)`` callable suitable for
        ``nn.Module.register_forward_hook``; it stores ``output.detach()``
        in ``activation[name]`` and returns ``None``.
    """
    def hook(model, input, output):
        # detach() drops the autograd graph so the stored tensor is plain data.
        activation[name] = output.detach()
    return hook

In [7]:
# Capture the output of the avgpool layer on every forward pass; the hook stores
# it under the "avgpool" key. Re-running this cell registers an additional hook.
model.avgpool.register_forward_hook(get_activaton("avgpool"))

<torch.utils.hooks.RemovableHandle at 0x24a52619720>

### CONVERT EVERY LIBRARY IMAGE TO A FEATURE VECTOR

In [8]:
# Embed every image in `images` with the model and collect one feature row per
# image. Outputs are rebuilt from scratch so that re-running this cell does not
# duplicate entries.
all_names = []
vec_rows = []
with torch.no_grad():
    for i, file in enumerate(images):
        try:
            # os.path.join supplies the path separator; convert("RGB") makes
            # grayscale / RGBA / palette images match the 3-channel transform.
            with Image.open(os.path.join(root, file)) as pil_img:
                img = transform(pil_img.convert("RGB"))
        except (OSError, ValueError) as err:
            # Only unreadable or non-image entries are skipped (PIL reports them
            # as OSError subclasses); genuine bugs in the model code are no
            # longer silently swallowed.
            print("skipping", file, "-", err)
            continue
        # The forward pass is run for its side effect: the avgpool hook stores
        # the pooled features in `activation`.
        model(img[None, ...])
        vec_rows.append(activation["avgpool"].numpy().reshape(-1))
        all_names.append(file)
        if i % 100 == 0 and i != 0:
            print(i, "done")

# Stack once at the end instead of re-allocating with np.vstack per image.
if vec_rows:
    all_vecs = np.stack(vec_rows)
else:
    # Keep a well-formed (0, feature_dim) array when nothing could be embedded.
    all_vecs = np.empty((0, model.fc.in_features), dtype=np.float32)

In [9]:
# Persist the embeddings and their file names so the front end can load them
# without recomputing anything.
for npy_path, payload in (("all_vecs.npy", all_vecs), ("all_names.npy", all_names)):
    np.save(npy_path, payload)

# FRONTEND USING STREAMLIT

In [10]:
import streamlit as st
import time
from scipy.spatial.distance import cdist

In [11]:
@st.cache_data
def read_data():
    """Load the saved embeddings and file names from disk.

    Reads ``all_vecs.npy`` and ``all_names.npy`` from the current working
    directory. The function is wrapped in ``st.cache_data``.

    Returns:
        tuple: ``(all_vecs, all_names)`` exactly as returned by ``np.load``.
    """
    return np.load("all_vecs.npy"), np.load("all_names.npy")



In [12]:
# Load the saved embedding array and the array of file names it was saved with.
vecs, names = read_data()

2024-06-05 14:02:22.264 
  command:

    streamlit run C:\Users\Hp\anaconda3\lib\site-packages\ipykernel_launcher.py [ARGUMENTS]
2024-06-05 14:02:22.267 No runtime found, using MemoryCacheStorageManager


In [13]:
# Three columns; only the middle one is kept, and it is used to show the
# currently selected image.
_, fcol2, _ = st.columns(3)

In [14]:
# Two columns that hold the two control buttons.
scol1, scol2 = st.columns(2)

In [15]:
# Truthy on the rerun triggered by clicking "Start/change" (selects a random image).
ch = scol1.button("Start/change")

In [16]:
# Truthy on the rerun triggered by clicking "find similar" (runs the neighbour search).
fs = scol2.button("find similar")

In [17]:
# Folder the displayed images are read from; file names are appended directly.
image_dir = "./Pictures/Camera Roll/"

if ch:
    # Pick a random library image, show it, and remember its name in the session
    # state so it is still known on the rerun triggered by "find similar".
    random_name = names[np.random.randint(len(names))]
    fcol2.image(Image.open(image_dir + random_name))
    st.session_state["disp_img"] = random_name
    st.write(st.session_state["disp_img"])
if fs:
    if "disp_img" not in st.session_state:
        # "find similar" was pressed before any image was selected.
        st.warning('Press "Start/change" to pick an image first.')
    else:
        query_name = st.session_state["disp_img"]
        # st.columns needs the number of columns; one per displayed neighbour.
        result_cols = st.columns(5)
        # flatnonzero yields a 1-D index array, so [0] is a plain scalar index.
        idx = int(np.flatnonzero(names == query_name)[0])
        target_vec = vecs[idx]
        fcol2.image(Image.open(image_dir + query_name))
        # Rank all images by distance to the query; position 0 is skipped because
        # the query itself is at distance zero.
        top5 = cdist(target_vec[None, ...], vecs)[0].argsort()[1:6]
        # zip stops early if the library holds fewer than five other images.
        for col, neighbour in zip(result_cols, top5):
            col.image(Image.open(image_dir + names[neighbour]))