In [1]:
%config Application.log_level = 'INFO'

In [27]:
import os
import sys
sys.path.append('..')
sys.path.append('../src')

import numpy as np
import pandas as pd
import torch
import logging

from torchvision.models import resnet50, ResNet50_Weights
from torchvision.io import decode_image

from src.image_encoder import PreTrainedImageEncoder
from src.image_metadata_process import extract_image_metadata
from src.image_process import load_image_and_metadata
from src.qdrant_vector_db import QdrantVectorDB

from qdrant_client.models import Distance

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# 1. Load Image Encoder

In [3]:
# cache weights to a project folder
os.environ['TORCH_HOME'] = '../cache'
# select device
device = torch.device("cuda") if torch.cuda.is_available else torch.device("cpu")

# Init model with pre-trained weights
pre_trained_weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=pre_trained_weights)

# Create image encoder from pre-trained model
image_encoder = PreTrainedImageEncoder(model, device=device)

# Set encoder to eval mode
image_encoder.eval()

# Init transforms
image_process = pre_trained_weights.transforms()
image_process

ImageClassification(
    crop_size=[224]
    resize_size=[232]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

# 2. Create Vector Database

In [36]:
# Db configs
db_name = 'image_db'
dim = 2048
distance = Distance.COSINE
db_path = '../vector_database'
db_client = QdrantVectorDB.init_client(db_path)

In [37]:
# If creating a db:
qdrant_db = QdrantVectorDB.create_db(
    database_name=db_name, 
    dimension=dim, 
    distance=distance, 
    database_path=db_path,
    client=db_client,
)

# If loading an existing db:
# qdrant_db = QdrantVectorDB.load_db(db_name, db_path)

# If removing an existing db:
# QdrantVectorDB.remove_db(database_name=db_name, database_path=db_path, client=qdrant_db.client)

INFO:root:Created Qdrant collection 'image_db' with dimension 2048 and distance 'Cosine'.


In [39]:
# load image info: Generated in Notebook `1. model_dataset_exploration`
df_image_info = pd.read_csv('../data/validation/file_pair_id.csv')
df_image_info.head(3)

Unnamed: 0,filename,pair_id,img_source
0,000001.json,1,user
1,000002.json,1,user
2,000003.json,3,user


In [40]:
# Select all images from shop
img_info_sel = df_image_info[df_image_info['img_source']=='shop']
img_info_sel.head(3)

Unnamed: 0,filename,pair_id,img_source
10844,010845.json,1,shop
10845,010846.json,1,shop
10846,010847.json,1,shop


In [41]:
# Create input to vector db
img_dir = '../data/validation/image'
metadata_dir = '../data/validation/annos'

img_paths = []
metadata_paths = []
for i, row in img_info_sel.iterrows():
    img_id = row['filename'].split('.')[0]
    img_paths.append(os.path.join(img_dir, f"{img_id}.jpg"))
    metadata_paths.append(os.path.join(metadata_dir, f"{img_id}.json"))

In [42]:
qdrant_db.add_images_batch(
    image_paths=img_paths, 
    metadata_paths=metadata_paths, 
    image_encoder=image_encoder, 
    image_transforms=image_process,
    batch_size=512,
    parallel=16,
    device=device
)

INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant collection 'image_db'.
INFO:root:Added 512 images to Qdrant col

In [43]:
qdrant_db.get_number_of_vectors()

21309