# Database Initialization
- Run the whole notebook to generate the pickle file.
- Set `dir_pth` to the directory that contains the images to index.
- If resnet18 and yolov10n are not installed locally, running the notebook will download both.

In [1]:
import torch
import os
import pickle
from PIL import Image
from torchvision import transforms
from ultralytics import YOLO

In [2]:
class ImageSearchAndTag:
    """
    Build a tag-indexed library of image embeddings for one directory.

    For every .jpg/.jpeg/.png file (case-insensitive) in ``dir_pth``:
      - yolov10n detects objects, and the unique detected class names are
        joined into a tag string such as ``"person:car:"``
        (``"NO_TAG"`` when nothing is detected);
      - resnet18 with its final layer removed produces an embedding,
        stored as float16.

    Results live in ``self.img_lib`` as ``{tag: {file_name: embedding}}``
    and can be written to disk with ``save_lib``.
    """

    # File extensions (lower-case, including the dot) treated as images.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, dir_pth):
        """
        Load both models and immediately index every image in ``dir_pth``.

        Args:
            dir_pth: path of the directory containing the images.
        """
        # COCO class names, keyed by the class id yolov10n predicts
        self.classes = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
          5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
          10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
          14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
          20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
          25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
          30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
          35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket',
          39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon',
          45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli',
          51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch',
          58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop',
          64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven',
          70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase',
          76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}

        # images directory; sorted so indexing order is reproducible across runs
        self.root_dir = dir_pth
        self.img_list = sorted(os.listdir(dir_pth))
        self.img_lib = {}

        # cosine similarity over 1-D embeddings; not used inside this class,
        # presumably used by the search code that consumes this object
        self.cos_sim = torch.nn.CosineSimilarity(dim=0)

        # run the embedding model on the GPU when one is available
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # resnet18 with its final (classification) layer dropped -> feature extractor.
        # Moved to the device once here instead of on every image.
        resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
        self.model_embd = torch.nn.Sequential(*(list(resnet.children())[:-1]))
        self.model_embd = self.model_embd.to(self.device)
        self.model_embd.eval()

        # ImageNet preprocessing; built once instead of per image
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # initializing yolov10n
        self.model_obj_detect = YOLO("yolov10n.pt")

        self.collect_embd()

    @classmethod
    def _is_image_file(cls, file_name):
        """Return True when ``file_name`` has one of IMG_EXTENSIONS (case-insensitive)."""
        return os.path.splitext(file_name)[1].lower() in cls.IMG_EXTENSIONS

    def _tag_from_classes(self, class_ids):
        """Build the tag: unique class names in id order, each followed by ':'; 'NO_TAG' if empty."""
        if not class_ids:
            return "NO_TAG"
        # sorted so the same set of objects always yields the same tag key
        return "".join(f"{self.classes[c]}:" for c in sorted(set(class_ids)))

    def collect_embd(self):
        """
        Fetch and store the tag and embedding of every image in the directory.

        Files whose extension is not an image extension are ignored. Images
        that fail to process are reported and skipped (best effort), so a
        single corrupt file does not abort indexing.
        """
        for img_id in self.img_list:
            if not self._is_image_file(img_id):
                continue
            img_pth = os.path.join(self.root_dir, img_id)
            try:
                img_embdings, tag = self.img_embd(img_pth)
            except Exception as e:  # deliberate best-effort boundary per image
                print(f"Error : {img_pth} : {e}")
                continue
            self.img_lib.setdefault(tag, {})[img_id] = img_embdings

    def img_embd(self, file_path):
        """
        Generate the resnet18 embedding and the yolov10n tag for one image.

        Args:
            file_path: path to the image file.

        Returns:
            (embedding, tag): the squeezed float16 model output and the tag string.
        """
        # convert() loads the pixel data, so the file can be closed right away.
        # Forcing RGB also handles grayscale/RGBA/palette/CMYK images, which
        # would otherwise fail the 3-channel Normalize step.
        with Image.open(file_path) as img:
            input_image = img.convert('RGB')

        obj_output = self.model_obj_detect(input_image)
        class_ids = obj_output[0].boxes.cls.to(torch.int).cpu().tolist()
        tag = self._tag_from_classes(class_ids)

        # add a batch dimension and move to the model's device
        input_batch = self.preprocess(input_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model_embd(input_batch).to(torch.float16)

        # NOTE(review): the embedding stays on self.device; when that is 'cuda',
        # the pickled library can only be unpickled where CUDA is available.
        # Confirm whether consumers expect CPU tensors.
        return output.squeeze(), tag

    def save_lib(self, file_path='database_coco_test.pickle'):
        """
        Save the tag-indexed embedding library as a pickle file.

        Args:
            file_path: output path; defaults to 'database_coco_test.pickle'
                in the current working directory.
        """
        with open(file_path, 'wb') as handle:
            pickle.dump(self.img_lib, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# Build the image library. The image directory can be overridden with the
# IMAGE_DIR environment variable instead of editing this machine-specific path.
image_store = ImageSearchAndTag(
    dir_pth=os.environ.get("IMAGE_DIR", "/home/rion/image_search_engine/data/test2017")
)

In [None]:
# Save the {tag: {file_name: embedding}} library built above to a pickle file
image_store.save_lib()