##### `Image Search with Qdrant and ConvBase for Feature Extraction`

In [15]:
## Import Libraries
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from dotenv import load_dotenv
import timm    ## PyTorch Image Models (timm)
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance, Batch, PointIdsList

In [2]:
## Load the .env file (override=True is passed to load_dotenv) and read the Qdrant connection settings
_ = load_dotenv(override=True)
qdrant_url = os.environ.get('QDRANT_URL')
qdrant_key = os.environ.get('QDRANT_API_KEY')

In [3]:
## Get the images paths
## os.listdir returns entries in arbitrary order, so the listing is sorted to keep the
## path -> id mapping (and the 500-image subset below) identical from run to run.
imgaes_paths = [os.path.join('dataset-images', img) for img in sorted(os.listdir('dataset-images'))]

## Create a DF that contains the images paths plus a sequential integer ID (starting at 3054) per image
df = pd.DataFrame({'paths': imgaes_paths})
df['id'] = np.arange(3054, 3054+len(df), 1)

## Take only the first 500 images --> for simplicity
df_use = df.iloc[:500]
df_use

Unnamed: 0,paths,id
0,dataset-images\0009fc27d9.jpg,3054
1,dataset-images\0014c2d720.jpg,3055
2,dataset-images\00196e8fac.jpg,3056
3,dataset-images\001fc748e6.jpg,3057
4,dataset-images\002bb8e03b.jpg,3058
...,...,...
495,dataset-images\16af889d9d.jpg,3549
496,dataset-images\16b44ef03b.jpg,3550
497,dataset-images\16b501e949.jpg,3551
498,dataset-images\16bbc4b4dc.jpg,3552


* `ConvBase for Feature Extraction`

In [4]:
## Feature extractor: a pretrained VGG19 from the timm library with its last child module dropped,
## used as a convbase. Flattening its output for one image gives a vector of length 4096.

vgg19_children = list(timm.create_model('vgg19', pretrained=True).children())
model = nn.Sequential(*vgg19_children[:-1]).eval()   ## .eval() returns the module itself (inference mode)

In [5]:
def extract_images_features(images_paths: list, batch_size: int = 32):
    ''' Take a list of images_paths and return the features extracted from them using the VGG19 Model.

    Args:
        images_paths: paths of the image files to run through the model.
        batch_size: number of images sent through the model in one forward pass (must be >= 1).

    Returns:
        A list with one flat (1D) list of floats per input path, in the same order as `images_paths`.
        An empty input gives an empty list.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    '''
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    ## Materialize once so that slicing and the emptiness check behave the same for any iterable
    images_paths = list(images_paths)
    if not images_paths:
        return []

    ## Transformation before extraction
    transform = transforms.Compose([
                            ## VGG required images (224, 224)
                            transforms.Resize((224, 224)),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                                   ])

    batch_features = []
    for chunk_start in range(0, len(images_paths), batch_size):
        chunk_paths = images_paths[chunk_start: chunk_start + batch_size]

        ## Load every image of the chunk; the context manager closes the file handle
        ## as soon as the pixels have been converted to a tensor
        image_tensors = []
        for image_path in chunk_paths:
            with Image.open(image_path) as image:
                image_tensors.append(transform(image.convert('RGB')))

        ## One forward pass for the whole chunk instead of one pass per image
        with torch.no_grad():
            conv_features = model(torch.stack(image_tensors))
            ## Flatten everything but the batch axis --> one 1D vector (as a list) per image
            batch_features.extend(conv_features.flatten(start_dim=1).tolist())

    return batch_features

## Test the above function
## The sample path is built with os.path.join so the cell also runs where '\' is not the path separator
sample_image_path = os.path.join('dataset-images', '0009fc27d9.jpg')
vgg19_vect_length = len(extract_images_features(images_paths=[sample_image_path])[0])
print(f'Vector Lenght using VGG19 Model is: {vgg19_vect_length}')

Vector Lenght using VGG19 Model is: 4096


* `Upserting to Qdrant`

In [10]:
## Set up Qdrant: describe the vectors, connect, then (re)create the collection

## Collection Configurations
collection_config = VectorParams(
    size=4096,                  ## the length of the VGG19 convbase feature vector
    distance=Distance.COSINE,   ## the similarity metric
    on_disk=True,               ## RAM optimizing
)

## Initialize a Client from the URL / API key read from the environment
client = QdrantClient(url=qdrant_url, api_key=qdrant_key)

## Create a Collection
client.recreate_collection(
    collection_name='image-search-course',
    vectors_config=collection_config,
)

True

In [11]:
## Check Status of Collection
## Fetch the collection info once so that both values come from the same response
collection_info = client.get_collection(collection_name='image-search-course')
collection_status = collection_info.status
collection_count_vectors = collection_info.vectors_count

print(f'Status is: {collection_status}')
print(f'Vectors Count is: {collection_count_vectors}')

Status is: green
Vectors Count is: 0


In [12]:
## Function for upserting data to Qdrant

def upsert_to_qdrant(df, batch_size=32, collection_name='image-search-course'):
    ''' Extract the features of the images listed in `df` and upsert them to Qdrant batch by batch.

    Args:
        df: DataFrame with a 'paths' column (image paths) and an 'id' column (point ids).
        batch_size: number of rows processed and upserted per request (must be >= 1).
        collection_name: name of the Qdrant collection to upsert into.

    Returns:
        A list with one entry per failed batch; each entry is the list of ids of that batch.
        The list is empty when every batch succeeded.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    '''
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    ## A list for failed_ids
    failed_ids = []

    for batch_start in tqdm(range(0, len(df), batch_size)):

        ## Prepare batches (positional slices). This is done outside the `try` so that
        ## `ids_batch` always refers to the current batch when a failure is recorded below.
        batch_end = min(batch_start+batch_size, len(df))
        paths_batch = df['paths'].iloc[batch_start: batch_end].tolist()
        ids_batch = df['id'].iloc[batch_start: batch_end].tolist()     ## Kept as integers (no conversion to string)

        try:
            ## Extract Features by calling the function
            batch_features = extract_images_features(images_paths=paths_batch)

            ## Prepare to Qdrant
            to_upsert = Batch(ids=ids_batch, vectors=batch_features)

            ## Upsert to Qdrant
            client.upsert(collection_name=collection_name, wait=True, points=to_upsert)

        except Exception as e:
            ## Best effort: report the problem, remember the batch and go on with the next one
            print(f'Error in upserting: {e}')
            failed_ids.append(ids_batch)

    return failed_ids


## Apply the function to the `df_use` subset, in batches of 32; keep whatever it reports as failed
failed_ids = upsert_to_qdrant(df=df_use, batch_size=32)

100%|██████████| 16/16 [05:39<00:00, 21.24s/it]


In [13]:
## Check Status of Collection after upserting
collection_status = client.get_collection(collection_name='image-search-course').status
collection_count_vectors = client.get_collection(collection_name='image-search-course').vectors_count

print(f'Status is: {collection_status}')
print(f'Vectors Count is: {collection_count_vectors}')

Status is: green
Vectors Count is: 500


In [14]:
## Inference in real-time
image_new_path = df['paths'].iloc[-1]
image_feats_new = extract_images_features(images_paths=[image_new_path])[0]

client.search(collection_name='image-search-course', query_vector=image_feats_new, limit=100, score_threshold=0.4)

[ScoredPoint(id=3263, version=6, score=0.7062105, payload={}, vector=None),
 ScoredPoint(id=3439, version=12, score=0.55330193, payload={}, vector=None),
 ScoredPoint(id=3080, version=0, score=0.4998963, payload={}, vector=None),
 ScoredPoint(id=3491, version=13, score=0.4963698, payload={}, vector=None),
 ScoredPoint(id=3124, version=2, score=0.46910954, payload={}, vector=None),
 ScoredPoint(id=3167, version=3, score=0.4662863, payload={}, vector=None),
 ScoredPoint(id=3204, version=4, score=0.41910368, payload={}, vector=None),
 ScoredPoint(id=3527, version=14, score=0.41092783, payload={}, vector=None),
 ScoredPoint(id=3177, version=3, score=0.4023589, payload={}, vector=None)]

In [16]:
## Example: delete a point from the collection by its id (uncomment the line below to run it)
# client.delete(collection_name='image-search-course', points_selector=PointIdsList(points=[3054]))

UpdateResult(operation_id=16, status=<UpdateStatus.COMPLETED: 'completed'>)