In [3]:
import torch
from torchvision import models, transforms
from torchvision.models import ResNet18_Weights
from PIL import Image
from scipy.spatial import distance

from pymilvus import connections, db
from pymilvus import CollectionSchema, FieldSchema, DataType

import numpy as np 

### Feature extractor

In [4]:

class FeatureExtractor_CNN:
    """Image feature extractor backed by an ImageNet-pretrained ResNet-18.

    The complete network output is used as the embedding vector. Images are
    resized, center-cropped to 224x224, converted to tensors and normalized
    before being passed to the model.
    """

    def __init__(self):
        """Load the pretrained model and build the preprocessing pipeline."""
        self.model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.model.eval()  # Set model to evaluation mode

        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Per-channel mean/std normalization (the usual ImageNet statistics).
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def preprocess_image(self, image_path):
        """Read an image from disk and turn it into a model-ready batch of one.

        Args:
            image_path: Path of the image file to open.

        Returns:
            The transformed image tensor with a leading batch dimension added.
        """
        # Open inside a context manager so the underlying file is closed
        # deterministically; convert() returns an independent in-memory image,
        # so it remains usable after the file has been closed.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert('RGB')
        return self.transform(rgb_image).unsqueeze(0)  # Add batch dimension

    def extract_features(self, image_tensor):
        """Run the model on a preprocessed batch and return a NumPy array.

        Args:
            image_tensor: Batch produced by ``preprocess_image``.

        Returns:
            The model output with the batch dimension squeezed away, moved to
            the CPU and converted to a NumPy array.
        """
        # Gradients are not needed for inference.
        with torch.no_grad():
            features = self.model(image_tensor).squeeze(0).cpu().numpy()
        return features

    def cosine_similarity(self, features1, features2):
        """Return the cosine similarity (1 - cosine distance) of two vectors."""
        return 1 - distance.cosine(features1, features2)

    def cosine_similarity_images(self, image1_path, image2_path):
        """Return the cosine similarity between the embeddings of two image files."""
        features1 = self.preprocess_extract_feature(image1_path)
        features2 = self.preprocess_extract_feature(image2_path)
        return self.cosine_similarity(features1, features2)

    def preprocess_extract_feature(self,img_path:str) -> np.ndarray:
        """Preprocess the image at ``img_path`` and return its embedding."""
        img_tensor = self.preprocess_image(img_path)
        return self.extract_features(img_tensor)




In [5]:
# Smoke test: build the extractor once and embed a single image to check the embedding shape
# (the recorded run prints (1000,)). `extractor` is reused by later cells.
extractor = FeatureExtractor_CNN()
image1_path = "/caltech-101/101_ObjectCategories/buddha/image_0006.jpg"
# image2_path is only referenced by the commented-out pairwise similarity check below.
image2_path = "/caltech-101/101_ObjectCategories/ant/image_0020.jpg"

try:
    # similarity_score = extractor.cosine_similarity_images(image1_path, image2_path)
    # print(f"Cosine similarity between images: {similarity_score}")
    embed_val = extractor.preprocess_extract_feature(image1_path)
    print(embed_val.shape)

except Exception as e:
    # Failures are printed instead of raised, so embed_val may be left undefined afterwards.
    print(f"Error: {e}")


(1000,)


## UI

In [None]:
# import gradio as gr

# def show_image_and_text(image, text):
#   """Displays the uploaded image and entered text."""
#   return image, text

# # Define the interface with input and output components
# interface = gr.Interface(
#   fn=show_image_and_text,
#   inputs=[
#     gr.Image(type="pil"),  # Allow uploading an image
#     gr.Textbox(label="Enter text:")  # Allow entering text
#   ],
#   outputs=[
#     gr.Image(type="pil", label="Uploaded image"),  # Display the uploaded image
#     gr.Textbox(label="Your text:")  # Display the entered text
#   ],
#   title="HeyBagh",
#   description="Upload a photo of object that you wish to search",
# )

# # Launch the interface
# interface.launch()


### Milvus Client

In [6]:


# Connect to the Milvus server on localhost:19530.
# NOTE(review): `conn` is not used anywhere else in the visible notebook; confirm whether
# connections.connect() returns anything worth keeping.
conn = connections.connect(host="127.0.0.1", port=19530)

In [None]:
# connections.disconnect("")

In [None]:
# Create Database
# database = db.create_database("heybagh_db")


In [7]:

# List the databases on the connected server (recorded run: 'default' and 'heybagh_db').
db.list_database()

['default', 'heybagh_db']

In [67]:
# Create the field schemas.
#
# Layout of one entity in the vector DB:
#   img_id:         INT64, primary key
#   img_cls_name:   VARCHAR (max 80), defaults to 'unknown'
#   img_rel_path:   VARCHAR (max 1000), defaults to 'unknown'
#   img_embeddings: FLOAT_VECTOR of EMBEDDING_DIM floats

# Must equal the length of the vectors produced by the feature extractor
# (the extractor smoke test in this notebook printed shape (1000,)).
EMBEDDING_DIM = 1000

img_id = FieldSchema(name='img_id', dtype=DataType.INT64, description='assigned Image ID', is_primary=True)

img_cls_name = FieldSchema(name='img_cls_name', dtype=DataType.VARCHAR, description='Class-Name of the Image it belongs to', default_value='unknown', max_length=80)

img_relative_path = FieldSchema(name='img_rel_path', dtype=DataType.VARCHAR, description='Relative path of the Image', default_value='unknown', max_length=1000)

img_embeddings = FieldSchema(name='img_embeddings', dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM, description='Extracted embeddings of Image')

In [68]:
# Create the collection schema from the four field definitions.
# auto_id=True: the recorded schema output shows the primary key img_id with auto_id enabled,
# and the later insert supplies only the three non-key columns.

schema =  CollectionSchema(fields=[img_id, img_cls_name, img_relative_path, img_embeddings],
                           description='heybagh image search',
                           enable_dynamic_field=True,
                           auto_id=True)

In [10]:
from pymilvus import Collection

# Create collection
# NOTE(review): this cell needs `schema` from the collection-schema cell; the recorded
# NameError appears to come from running it on a kernel where that cell had not been
# executed yet (execution counts are out of order).
collection_name = 'heybagh_caltech101_imgs'

hey_bagh_caltech_collect = Collection(name=collection_name,
                        schema=schema,
                        shards_num=2,)

NameError: name 'schema' is not defined

In [70]:
from pymilvus import utility
# Sanity check: confirm the collection exists, then show its schema and indexes
# (the recorded run shows an empty index list at this point).
print("Check if collection exists: ", utility.has_collection(collection_name))

# collection = Collection(collection_name)

print(hey_bagh_caltech_collect.schema)

print(hey_bagh_caltech_collect.indexes)

Check if collection exists:  True
{'auto_id': True, 'description': 'heybagh image search', 'fields': [{'name': 'img_id', 'description': 'assigned Image ID', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': True}, {'name': 'img_cls_name', 'description': 'Class-Name of the Image it belongs to', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 80}}, {'name': 'img_rel_path', 'description': 'Relative path of the Image', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 1000}}, {'name': 'img_embeddings', 'description': 'Extracted embeddings of Image', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 1000}}], 'enable_dynamic_field': True}
[]


In [71]:
### Prepare data for insertion into the collection

In [39]:
# Root directory of the Caltech-101 images that are embedded and inserted into the collection;
# relative image paths are resolved against it.
root_path_to_caltech_101_dataset = "/datasets/caltech-101/101_ObjectCategories"

import os
from tqdm import tqdm

def traverse_directories(root_dir):
    """Walk ``root_dir`` recursively and yield one tuple per file found.

    Args:
        root_dir: Directory to walk.

    Yields:
        ``(filename, relative_path, subdirectory_name)`` where
        ``relative_path`` is the file's path relative to ``root_dir`` (join it
        with ``root_dir`` to obtain the full path) and ``subdirectory_name``
        is the name of the directory that directly contains the file.
    """
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # os.walk reports entries in arbitrary filesystem order. Sorting the
        # directory list in place (which also steers the walk) and the file
        # list makes the traversal order reproducible across runs.
        dirnames.sort()
        for filename in sorted(filenames):
            file_path = os.path.join(dirpath, filename)
            relative_path = os.path.relpath(file_path, root_dir)
            yield filename, relative_path, os.path.basename(dirpath)



In [None]:
# Prepare the data for the collection: one class name, relative path and
# embedding per image found under the dataset root.
import pandas as pd

entity_img_cls_name = []
entity_img_rel_path = []
entity_img_embedding = []

# Embedding every image is the slow part of the notebook, so show progress.
for filename, rel_path, subdir_name in tqdm(
    traverse_directories(root_path_to_caltech_101_dataset),
    desc="Extracting embeddings",
):
    entity_img_cls_name.append(subdir_name)
    entity_img_rel_path.append(rel_path)
    # os.path.join instead of manual '/' concatenation keeps the path valid
    # regardless of trailing separators or platform.
    full_path_img = os.path.join(root_path_to_caltech_101_dataset, rel_path)
    entity_img_embedding.append(extractor.preprocess_extract_feature(full_path_img))

# Column names match the field names of the collection schema.
df = pd.DataFrame({"img_cls_name": entity_img_cls_name,
                   "img_rel_path": entity_img_rel_path,
                   "img_embeddings": entity_img_embedding
                   })


In [72]:
# Column-wise payload for the insert: one list per non-key field, in schema order.
entities = [
    df[column].to_list()
    for column in ("img_cls_name", "img_rel_path", "img_embeddings")
]

hey_bagh_caltech_collect.insert(entities)

(insert count: 8676, delete count: 0, upsert count: 0, timestamp: 448239057656086531, success count: 8676, err count: 0)

In [73]:
# Flush the collection after the insert.
# NOTE(review): presumably this persists the inserted rows so that num_entities is accurate;
# confirm against the pymilvus documentation.
hey_bagh_caltech_collect.flush()

### Collection info

In [8]:
from pymilvus import utility

# List the collections on the connected server (recorded run: only 'heybagh_caltech101_imgs').
utility.list_collections()

['heybagh_caltech101_imgs']

In [75]:
from pymilvus import Collection

# Handle to the already existing collection.
collection = Collection("heybagh_caltech101_imgs")

# Print the collection's metadata, one attribute per line, in this order:
# schema, description, name, emptiness flag, entity count, primary-key field,
# partitions and indexes.
for attribute_name in (
    "schema",
    "description",
    "name",
    "is_empty",
    "num_entities",
    "primary_field",
    "partitions",
    "indexes",
):
    print(getattr(collection, attribute_name))
# print(collection.properties)		# Return the expiration time of data in the collection.


{'auto_id': True, 'description': 'heybagh image search', 'fields': [{'name': 'img_id', 'description': 'assigned Image ID', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': True}, {'name': 'img_cls_name', 'description': 'Class-Name of the Image it belongs to', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 80}}, {'name': 'img_rel_path', 'description': 'Relative path of the Image', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 1000}}, {'name': 'img_embeddings', 'description': 'Extracted embeddings of Image', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 1000}}], 'enable_dynamic_field': True}
heybagh image search
heybagh_caltech101_imgs
False
8676
{'name': 'img_id', 'description': 'assigned Image ID', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': True}
[{"name":"_default","collection_name":"heybagh_caltech101_imgs","description":""}]
[]


### Update collection with relative paths

### Creating Index

In [76]:
# Index parameters: IVF_FLAT index compared by cosine similarity, with 512 clusters (nlist).
index = dict(
    index_type="IVF_FLAT",
    metric_type="COSINE",
    params=dict(nlist=512),
)

In [22]:
# hey_bagh_caltech_collect.drop_index()

In [77]:
# Build the index described by `index` on the embeddings field (recorded run: Status code 0).
hey_bagh_caltech_collect.create_index(field_name='img_embeddings', index_params=index)

Status(code=0, message=)

In [78]:
# Report index build progress (recorded run: all 8676 rows indexed, none pending).
utility.index_building_progress(collection_name=collection_name)

{'total_rows': 8676, 'indexed_rows': 8676, 'pending_index_rows': 0}

In [79]:
# Number of entities in the collection (8676 in the recorded run, matching the insert count).
hey_bagh_caltech_collect.num_entities

8676

### Do Vector Search

In [11]:
# Re-bind the handle to the existing collection by name, then load it.
# NOTE(review): presumably load() is required before search() can be called; confirm
# against the pymilvus documentation.
hey_bagh_caltech_collect = Collection(name="heybagh_caltech101_imgs")
hey_bagh_caltech_collect.load()

In [12]:
import time

# Pure vector-similarity search: embed one query image and fetch its 10 closest matches.
print("Start searching based on vector similarity")

# Query image
image_path = "/caltech-101/101_ObjectCategories/brontosaurus/image_0006.jpg"

# The query is passed as a list of vectors, so the single embedding is wrapped in a list.
img_vector_to_search = [extractor.preprocess_extract_feature(image_path).tolist()]

# Same metric as the index; nprobe is the search-time parameter.
search_params = dict(metric_type="COSINE", params=dict(nprobe=8))

# Time only the search call itself.
start_time = time.time()
result = hey_bagh_caltech_collect.search(
    img_vector_to_search,
    "img_embeddings",
    search_params,
    limit=10,
    output_fields=["img_cls_name", "img_rel_path"],
)
end_time = time.time()

# One group of hits per query vector; print every hit with its class name.
for hit in (single_hit for hits_for_query in result for single_hit in hits_for_query):
    print(f"hit: {hit}, Class Name: {hit.entity.get('img_cls_name')}")
print("Time to search", end_time - start_time)


Start searching based on vector similarity
hit: id: 448237751945690752, distance: 0.9999998211860657, entity: {'img_rel_path': 'brontosaurus/image_0006.jpg', 'img_cls_name': 'brontosaurus'}, Class Name: brontosaurus
hit: id: 448237751945690742, distance: 0.8235660195350647, entity: {'img_rel_path': 'brontosaurus/image_0007.jpg', 'img_cls_name': 'brontosaurus'}, Class Name: brontosaurus
hit: id: 448237751945686664, distance: 0.7548072934150696, entity: {'img_rel_path': 'stegosaurus/image_0027.jpg', 'img_cls_name': 'stegosaurus'}, Class Name: stegosaurus
hit: id: 448237751945685572, distance: 0.7245089411735535, entity: {'img_rel_path': 'kangaroo/image_0034.jpg', 'img_cls_name': 'kangaroo'}, Class Name: kangaroo
hit: id: 448237751945692433, distance: 0.7190746665000916, entity: {'img_rel_path': 'elephant/image_0050.jpg', 'img_cls_name': 'elephant'}, Class Name: elephant
hit: id: 448237751945692479, distance: 0.7045187950134277, entity: {'img_rel_path': 'elephant/image_0047.jpg', 'img_cls

In [21]:
# Print the relative path of every hit, in result order.
for rel_path in (
    hit.entity.get('img_rel_path') for hits_for_query in result for hit in hits_for_query
):
    print(rel_path)

brontosaurus/image_0006.jpg
brontosaurus/image_0007.jpg
stegosaurus/image_0027.jpg
kangaroo/image_0034.jpg
elephant/image_0050.jpg
elephant/image_0047.jpg
rhino/image_0025.jpg
elephant/image_0060.jpg
elephant/image_0029.jpg
rhino/image_0003.jpg


In [85]:
# Hybrid search: vector similarity restricted by a scalar filter on the class name.
target_cls_name = "brontosaurus"
# The announcement is built from the same value as the filter so the two cannot drift apart
# (it previously claimed 'buddha' while the filter used 'brontosaurus').
print(f"Start hybrid searching with img_cls_name == {target_cls_name}")

start_time = time.time()
result = hey_bagh_caltech_collect.search(
    img_vector_to_search,
    "img_embeddings",
    search_params,
    limit=10,
    # NOTE(review): LIKE without a wildcard presumably behaves as an exact match; confirm
    # against the Milvus filter-expression documentation.
    expr=f"img_cls_name LIKE '{target_cls_name}'",
    output_fields=["img_cls_name"],
)
end_time = time.time()

# One group of hits per query vector; print every hit with its class name.
for hits in result:
    for hit in hits:
        print(f"hit: {hit}| Class Name: {hit.entity.get('img_cls_name')}")
print("Time to search", end_time - start_time)

Start hybrid searching with img_cls_name == buddha
hit: id: 448237751945690752, distance: 0.9999998211860657, entity: {'img_cls_name': 'brontosaurus'}| Class Name: brontosaurus
hit: id: 448237751945690742, distance: 0.8235660195350647, entity: {'img_cls_name': 'brontosaurus'}| Class Name: brontosaurus
hit: id: 448237751945690769, distance: 0.6470790505409241, entity: {'img_cls_name': 'brontosaurus'}| Class Name: brontosaurus
hit: id: 448237751945690747, distance: 0.6464439034461975, entity: {'img_cls_name': 'brontosaurus'}| Class Name: brontosaurus
hit: id: 448237751945690748, distance: 0.6416621208190918, entity: {'img_cls_name': 'brontosaurus'}| Class Name: brontosaurus
hit: id: 448237751945690739, distance: 0.6411221623420715, entity: {'img_cls_name': 'brontosaurus'}| Class Name: brontosaurus
hit: id: 448237751945690771, distance: 0.6107290983200073, entity: {'img_cls_name': 'brontosaurus'}| Class Name: brontosaurus
hit: id: 448237751945690746, distance: 0.6102967262268066, entity: 