In [5]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import os


model = models.resnet50(pretrained=True)
model = torch.nn.Sequential(*list(model.children())[:-1])
model.eval()


preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),

])

def extract_features(image_path):
    img = Image.open(image_path).convert("RGB")
    img_tensor = preprocess(img).unsqueeze(0)
    with torch.no_grad():
        features = model(img_tensor)
    return features.squeeze().numpy()



In [None]:
%pip install annoy

Collecting annoy
  Downloading annoy-1.17.3.tar.gz (647 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/647.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.4/647.5 kB[0m [31m4.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m647.5/647.5 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: annoy
  Building wheel for annoy (setup.py) ... [?25l[?25hdone
  Created wheel for annoy: filename=annoy-1.17.3-cp311-cp311-linux_x86_64.whl size=553317 sha256=110f8f5f57b546dcb6febc792065c4a3191edd10b63061413c92f8a1c34bf3d7
  Stored in directory: /root/.cache/pip/wheels/33/e5/58/0a3e34b92bedf09b4c57e37a63ff395ade6f6c1099ba59877c
Successfully built annoy
Installing collected packages: annoy
Successfully installed annoy-1.17.3


In [6]:
from annoy import AnnoyIndex
import numpy as np

feature_dim = 2048
annoy_index = AnnoyIndex(feature_dim, 'euclidean')

image_paths = []
for idx, image_file in enumerate(os.listdir("dataset")):
    path = os.path.join("dataset", image_file)
    features = extract_features(path)
    annoy_index.add_item(idx, features)
    image_paths.append(path)

annoy_index.build(10)  # 10 trees
annoy_index.save('image_index.ann')

True

In [7]:
def search_similar_images(query_path, top_k=5):
    query_vec = extract_features(query_path)
    indices = annoy_index.get_nns_by_vector(query_vec, top_k)
    return [image_paths[i] for i in indices]

results = search_similar_images("dataset/togepi.png")
for res in results:
    print(res)


dataset/togepi.png
dataset/vanillish.png
dataset/staryu.png
dataset/seaking.png
dataset/vanillite.png
