In [1]:
import torch
from torchvision import models, transforms
from torchvision.models import ResNet18_Weights
from PIL import Image
from scipy.spatial import distance

import numpy as np 

### feature extractor

In [3]:

class FeatureExtractor_CNN:
    def __init__(self):
        self.model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.model.eval()  # Set model to evaluation mode

        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def preprocess_image(self, image_path):
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0)  # Add batch dimension
        return image_tensor

    def extract_features(self, image_tensor):
        with torch.no_grad():
            features = self.model(image_tensor).squeeze(0).cpu().numpy()
        return features

    def cosine_similarity(self, features1, features2):
        similarity = 1 - distance.cosine(features1, features2)
        return similarity

    def cosine_similarity_images(self, image1_path, image2_path):
        img1_tensor = self.preprocess_image(image1_path)
        img2_tensor = self.preprocess_image(image2_path)
        
        features1 = self.extract_features(img1_tensor)
        features2 = self.extract_features(img2_tensor)

        return self.cosine_similarity(features1, features2)

    def preprocess_extract_feature(self,img_path:str) -> np.ndarray:
        img_tensor = self.preprocess_image(img_path)
        features = self.extract_features(img_tensor)
        return features




In [4]:
extractor = FeatureExtractor_CNN()
image1_path = "/home/c3po/Documents/project/learning/amar-works/datasets/caltech-101/101_ObjectCategories/buddha/image_0006.jpg"
image2_path = "/home/c3po/Documents/project/learning/amar-works/HeyBagh/data/caltech-101/101_ObjectCategories/ant/image_0020.jpg"

try:
    # similarity_score = extractor.cosine_similarity_images(image1_path, image2_path)
    # print(f"Cosine similarity between images: {similarity_score}")
    embed_val = extractor.preprocess_extract_feature(image1_path)
    print(embed_val.shape)

except Exception as e:
    print(f"Error: {e}")


(1000,)


## UI

In [None]:
import gradio as gr

def show_image_and_text(image, text):
  """Displays the uploaded image and entered text."""
  return image, text

# Define the interface with input and output components
interface = gr.Interface(
  fn=show_image_and_text,
  inputs=[
    gr.Image(type="pil"),  # Allow uploading an image
    gr.Textbox(label="Enter text:")  # Allow entering text
  ],
  outputs=[
    gr.Image(type="pil", label="Uploaded image"),  # Display the uploaded image
    gr.Textbox(label="Your text:")  # Display the entered text
  ],
  title="HeyBagh",
  description="Upload a photo of object that you wish to search",
)

# Launch the interface
interface.launch()


### Milvus Client

In [2]:
from pymilvus import connections, db

conn = connections.connect(host="127.0.0.1", port=19530)

In [None]:
# connections.disconnect("")

In [None]:
# Create Database
# database = db.create_database("heybagh_db")


In [6]:

db.list_database()

['default', 'heybagh_db']

In [7]:
# create fields schema  

from pymilvus import CollectionSchema, FieldSchema, DataType

"""
schema for vector-db
{
    img_id: INT
    img_cls_name: STR
    img_embeddings: VECTOR-FLOAT
}
"""

img_id = FieldSchema(name='img_id', dtype=DataType.INT64, description='assigned Image ID', is_primary=True)

img_cls_name = FieldSchema(name='img_cls_name', dtype=DataType.VARCHAR, description='Class-Name of the Image it belongs to', default_value='unknown', max_length=80)

img_embeddings = FieldSchema(name='img_embeddings', dtype=DataType.FLOAT_VECTOR, dim=1000, description='Extracted embeddings of Image')

In [8]:
# create colleection schema

schema =  CollectionSchema(fields=[img_id, img_cls_name, img_embeddings],
                           description='heybagh image search',
                           enable_dynamic_field=True,
                           auto_id=True)

In [10]:
from pymilvus import Collection

# Create collection
collection_name = 'heybagh_caltech101_imgs'

hey_bagh_caltech_collect = Collection(name=collection_name,
                        schema=schema,
                        shards_num=2,)

In [12]:
from pymilvus import utility
print("Check if collection exists: ", utility.has_collection(collection_name))

# collection = Collection(collection_name)

print(hey_bagh_caltech_collect.schema)

print(hey_bagh_caltech_collect.indexes)

Check if collection exists:  True
{'auto_id': True, 'description': 'heybagh image search', 'fields': [{'name': 'img_id', 'description': 'assigned Image ID', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': True}, {'name': 'img_cls_name', 'description': 'Class-Name of the Image it belongs to', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 80}}, {'name': 'img_embeddings', 'description': 'Extracted embeddings of Image', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 1000}}], 'enable_dynamic_field': True}
[<pymilvus.orm.index.Index object at 0x7f75ace4d1b0>]


In [41]:
# Insert data into collection to create index
root_path_to_caltech_101_dataset = "/home/c3po/Documents/project/learning/amar-works/datasets/caltech-101/101_ObjectCategories"

import os
from tqdm import tqdm

def traverse_directories(root_dir):
    """Traverses all directories under the given root directory and yields file and subdirectory information."""
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            yield filename, file_path, os.path.basename(dirpath)  # Yield filename, full path, subdirectory name


entity_img_cls_name = []
entity_img_embedding = []
for filename, file_path, subdir_name in traverse_directories(root_path_to_caltech_101_dataset):
    print(subdir_name, filename)
    entity_img_cls_name.append(subdir_name)
    entity_img_embedding.append(extractor.preprocess_extract_feature(file_path))
entities = [entity_img_cls_name, entity_img_embedding]

print("Len img_cls", len(entities[0]))
print("len img_embeds", len(entities[1]))

Faces image_0223.jpg
Faces image_0043.jpg
Faces image_0380.jpg
Faces image_0076.jpg
Faces image_0172.jpg
Faces image_0356.jpg
Faces image_0263.jpg
Faces image_0063.jpg
Faces image_0134.jpg
Faces image_0104.jpg
Faces image_0068.jpg
Faces image_0311.jpg
Faces image_0163.jpg
Faces image_0428.jpg
Faces image_0016.jpg
Faces image_0265.jpg
Faces image_0205.jpg
Faces image_0127.jpg
Faces image_0152.jpg
Faces image_0022.jpg
Faces image_0340.jpg
Faces image_0012.jpg
Faces image_0231.jpg
Faces image_0245.jpg
Faces image_0270.jpg
Faces image_0379.jpg
Faces image_0200.jpg
Faces image_0130.jpg
Faces image_0162.jpg
Faces image_0394.jpg
Faces image_0420.jpg
Faces image_0203.jpg
Faces image_0297.jpg
Faces image_0040.jpg
Faces image_0019.jpg
Faces image_0365.jpg
Faces image_0367.jpg
Faces image_0352.jpg
Faces image_0266.jpg
Faces image_0029.jpg
Faces image_0141.jpg
Faces image_0098.jpg
Faces image_0119.jpg
Faces image_0361.jpg
Faces image_0327.jpg
Faces image_0281.jpg
Faces image_0412.jpg
Faces image_0

In [18]:
hey_bagh_caltech_collect.insert(entities)

(insert count: 8676, delete count: 0, upsert count: 0, timestamp: 448154668111233026, success count: 8676, err count: 0)

In [19]:
hey_bagh_caltech_collect.flush()

In [3]:
from pymilvus import utility

utility.list_collections()

['heybagh_caltech101_imgs']

### Creating Index

In [20]:
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 512},
}

In [22]:
hey_bagh_caltech_collect.drop_index()

In [23]:
hey_bagh_caltech_collect.create_index('img_embeddings', index)

Status(code=0, message=)

In [24]:
utility.index_building_progress(collection_name=collection_name)

{'total_rows': 8676, 'indexed_rows': 8676, 'pending_index_rows': 0}

In [25]:
hey_bagh_caltech_collect.num_entities

8676

### Do Vector Search

In [26]:
hey_bagh_caltech_collect.load()

In [53]:
import time
# search based on vector similarity
print("Start searching based on vector similarity")
img_vector_to_search = [extractor.preprocess_extract_feature(image1_path).tolist()]
# img_vector_to_search = entities[-1][-2:]
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": 10},
}

start_time = time.time()
result = hey_bagh_caltech_collect.search(img_vector_to_search, "img_embeddings", search_params, limit=3, output_fields=["img_cls_name"])
end_time = time.time()

for hits in result:
    for hit in hits:
        print(f"hit: {hit}, Class Name: {hit.entity.get('img_cls_name')}")
print("Time to search", end_time - start_time)

print(type(img_vector_to_search))

Start searching based on vector similarity
hit: id: 448153848572826763, distance: 0.0, entity: {'img_cls_name': 'buddha'}, Class Name: buddha
hit: id: 448153848572826791, distance: 3024.68994140625, entity: {'img_cls_name': 'buddha'}, Class Name: buddha
hit: id: 448153848572826775, distance: 3863.55078125, entity: {'img_cls_name': 'buddha'}, Class Name: buddha
Time to search 0.004317522048950195
<class 'list'>


In [55]:
# hybrid search
print("Start hybrid searching with img_cls_name == buddha")

start_time = time.time()
result = hey_bagh_caltech_collect.search(img_vector_to_search, "img_embeddings", search_params, limit=3, expr="img_cls_name LIKE 'buddha' ", output_fields=["img_cls_name"])
end_time = time.time()

for hits in result:
    for hit in hits:
        print(f"hit: {hit}| Class Name: {hit.entity.get('img_cls_name')}")
print("Time to search", end_time - start_time)

Start hybrid searching with img_cls_name == buddha
hit: id: 448153848572826763, distance: 0.0, entity: {'img_cls_name': 'buddha'}| Class Name: buddha
hit: id: 448153848572826791, distance: 3024.68994140625, entity: {'img_cls_name': 'buddha'}| Class Name: buddha
hit: id: 448153848572826775, distance: 3863.55078125, entity: {'img_cls_name': 'buddha'}| Class Name: buddha
Time to search 0.008576154708862305
