In [4]:
%cd /content/drive/MyDrive/BLIP/BLIP

/content/drive/MyDrive/BLIP/BLIP


In [5]:
!pip install annoy

Collecting annoy
  Downloading annoy-1.17.3.tar.gz (647 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m647.5/647.5 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: annoy
  Building wheel for annoy (setup.py) ... [?25l[?25hdone
  Created wheel for annoy: filename=annoy-1.17.3-cp310-cp310-linux_x86_64.whl size=552447 sha256=8e8752baa6ac79c37c2fb7a707408b0341f22bb77a3be865926f58248d9c9525
  Stored in directory: /root/.cache/pip/wheels/64/8a/da/f714bcf46c5efdcfcac0559e63370c21abe961c48e3992465a
Successfully built annoy
Installing collected packages: annoy
Successfully installed annoy-1.17.3


In [6]:
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import json
from annoy import AnnoyIndex

# Use the GPU when torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [7]:
class MyAnnoy:
  """Build and query an Annoy nearest-neighbour index over image features.

  Images are encoded with ``model.image_feature(...)``; text queries with
  ``model.text_feature(...)``. Both are indexed with ``[0]`` below, i.e. they
  are assumed to return a batch whose first element is the feature vector for
  the single input (TODO: confirm against the model implementation).

  Args:
    model: object exposing ``image_feature(image_tensor)`` and ``text_feature(text)``.
    image_size: side length in pixels that images are resized to before encoding.
    json_path: path to a JSON object mapping image path -> id (anything ``int()`` accepts).
    metric: Annoy distance metric name (e.g. "angular", "euclidean", "dot").
    ntrees: number of trees passed to ``AnnoyIndex.build``.
    feature_shape: dimensionality of the feature vectors stored in the index.
    device: torch device for image tensors; defaults to CUDA when available, else CPU.
  """

  def __init__(self, model, image_size, json_path, metric, ntrees, feature_shape, device=None):
    self.model = model
    self.image_size = image_size
    self.dict_id2image_path = self.load_json(json_path)
    self.metric = metric
    self.ntrees = ntrees
    self.feature_shape = feature_shape
    self.device = self._resolve_device(device)

  @staticmethod
  def _resolve_device(device=None):
    """Return ``device`` as a ``torch.device``; ``None`` means CUDA if available, else CPU."""
    if device is None:
      return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device(device)

  @staticmethod
  def _to_vector(feature):
    """Convert a feature (torch tensor or numeric sequence) to a flat list of floats.

    Annoy's ``add_item`` / ``get_nns_by_vector`` expect a plain sequence of
    numbers; detaching and moving to CPU avoids per-element device syncs.
    """
    if isinstance(feature, torch.Tensor):
      return feature.detach().cpu().reshape(-1).tolist()
    return [float(x) for x in feature]

  def load_json(self, json_path):
    """Parse ``json_path`` (image path -> id) into a dict ``{int(id): image_path}``."""
    with open(json_path, 'r') as f:
      path2id = json.load(f)

    return {int(idx): image_path for image_path, idx in path2id.items()}

  def buildAnnoyIndex(self, save_path):
    """Encode every image in the id mapping, build the index and save it.

    Each image is stored under its id from the JSON mapping, so search
    results map back to paths through ``self.dict_id2image_path``.

    Args:
      save_path: file path the built index is written to.

    Returns:
      The built ``AnnoyIndex``.
    """
    annoy_index = AnnoyIndex(self.feature_shape, self.metric)

    # Iterate the mapping itself: ids need not be contiguous from 0.
    for image_id, image_path in self.dict_id2image_path.items():
      image_input = self.load_image(image_path, self.image_size, self.device)

      with torch.no_grad():  # inference only; don't build autograd graphs
        output_feature = self.model.image_feature(image_input)[0]

      annoy_index.add_item(image_id, self._to_vector(output_feature))

    annoy_index.build(self.ntrees)
    annoy_index.save(save_path)
    return annoy_index

  def load_annoy(self, annoy_path):
    """Load a saved index; ``feature_shape`` and ``metric`` must match those used to build it."""
    annoy_index = AnnoyIndex(self.feature_shape, self.metric)
    annoy_index.load(annoy_path)
    return annoy_index

  def annoy_image_search(self, annoy_index, image_path, topk):
    """Return the ids of the ``topk`` indexed items nearest to the image at ``image_path``."""
    image_input = self.load_image(image_path, self.image_size, self.device)

    with torch.no_grad():
      output_feature = self.model.image_feature(image_input)[0]

    return annoy_index.get_nns_by_vector(self._to_vector(output_feature), topk)

  def annoy_text_image_search(self, annoy_index, text, topk):
    """Return the ids of the ``topk`` indexed items nearest to the embedding of ``text``."""
    with torch.no_grad():
      output_feature = self.model.text_feature(text)[0]

    return annoy_index.get_nns_by_vector(self._to_vector(output_feature), topk)

  @staticmethod
  def load_image(image_path, image_size, device=None):
    """Load an image as a normalized tensor of shape (1, 3, image_size, image_size).

    Args:
      image_path: path to the image file.
      image_size: output height and width in pixels (bicubic resize).
      device: target torch device; ``None`` means CUDA if available, else CPU.
    """
    # Close the file handle promptly and force 3 channels: grayscale/RGBA/palette
    # images would otherwise fail the 3-channel Normalize below.
    with Image.open(image_path) as img:
      rgb = img.convert('RGB')

    transform = transforms.Compose([
      transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
      transforms.ToTensor(),
      # Per-channel mean/std; these appear to be the CLIP/BLIP preprocessing constants.
      transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])
    return transform(rgb).unsqueeze(0).to(MyAnnoy._resolve_device(device))
