In [50]:
import os
import glob
import torch
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
import torch.nn.functional as F
import uuid
from tqdm import tqdm

import yaml
from oml.registry.models import get_extractor_by_cfg
from shared import resize_and_pad_image_cv2

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, SearchRequest

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
transform, _ = get_transforms_for_pretrained("vitb14_dinov2")
model_path = "./weights/birds_200_vitb14_dinov2/checkpoints/best.ckpt"

extractor = get_extractor_by_cfg({
    'name': 'vit',
    'args': {
        'normalise_features': True,
        'use_multi_scale': False,
        'weights': model_path,
        'arch': 'vitb14'
    }
}).to(DEVICE)

def get_embedding(image_path):
    with torch.no_grad():
        image = Image.open(image_path).convert("RGB")
        image = Image.fromarray(resize_and_pad_image_cv2(np.array(image)))
        image = transform(image).unsqueeze(0)
        embeddings = extractor(image.to(DEVICE))
        return embeddings
    
embedding1 = get_embedding("./downloads/birds-200-species/CUB_200_2011/train/001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg")
embedding2 = get_embedding("./downloads/birds-200-species/CUB_200_2011/train/001.Black_footed_Albatross/Black_Footed_Albatross_0058_796074.jpg")
embedding3 = get_embedding("./downloads/birds-200-species/CUB_200_2011/train/078.Gray_Kingbird/Gray_Kingbird_0051_70139.jpg")

F.cosine_similarity(embedding1, embedding2, dim=1).item(), F.cosine_similarity(embedding1, embedding3, dim=1).item()

Prefix <model.model.> was removed from the state dict.


(0.803785502910614, 0.060847196727991104)

In [46]:
COLLECTION_NAME = "image-embedding-training"
EMBEDDING_SIZE = embedding1.shape[1]

# Initialize qdrant client
client = QdrantClient(
    host="image_embeddings_qdrant",
    port=6333,
    grpc_port=6334,
    prefer_grpc=True
)

In [47]:
train_images = glob.glob("./downloads/birds-200-species/CUB_200_2011/train/*/*")
val_images = glob.glob("./downloads/birds-200-species/CUB_200_2011/val/*/*")

In [51]:
# check if collection exists
existing_collections = [c.name for c in client.get_collections().collections]

# re-create collection
if COLLECTION_NAME in existing_collections:
    client.delete_collection(collection_name=COLLECTION_NAME)
    existing_collections = [c.name for c in client.get_collections().collections]

# create collection if doesn't exists
if not (COLLECTION_NAME in existing_collections):
    client.recreate_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=EMBEDDING_SIZE, distance=Distance.COSINE),
    )

total_records_qdrant = client.http.collections_api.get_collection(
    COLLECTION_NAME
).dict()["result"]["points_count"] or 0

# print(f"Total records inside Qdrant: {total_records_qdrant}")

def index_data(batch):
    with torch.no_grad():
        images = torch.stack([
            transform(Image.fromarray(resize_and_pad_image_cv2(np.array(Image.open(x[0]).convert("RGB"))))) for x in batch
        ])
        embeddings = extractor(images.to(DEVICE)).cpu().numpy()
        labels = [x[1] for x in batch]

    points = []
    for embedding, label in zip(embeddings, labels):
        embedding = embedding.tolist()
        payload_id = uuid.uuid1().int >> 64
        points.append(
            PointStruct(
                id=payload_id,
                payload={"label": label},
                vector=embedding,
            )
        )
    
    client.upsert(collection_name=COLLECTION_NAME, points=points, wait=True)

batch_size = 128
batch = []

for image in tqdm(train_images):
    label = os.path.basename(os.path.dirname(image))

    if len(batch) >= batch_size:
        index_data(batch)
        batch = []

    batch.append([image, label])

if len(batch):
    index_data(batch)
    batch = []
    
def search_data(batch):
    with torch.no_grad():
        images = torch.stack([
            transform(Image.fromarray(resize_and_pad_image_cv2(np.array(Image.open(x[0]).convert("RGB"))))) for x in batch
        ])
        embeddings = extractor(images.to(DEVICE)).cpu().numpy()
        labels = [x[1] for x in batch]
    
    search_queries = [
        SearchRequest(
            vector=embedding,
            filter=None,
            limit=1,
            with_payload=True,
        ) for embedding in embeddings
    ]
    
    # Search for matching embedding vector and it's label in qdrant db
    res = client.search_batch(collection_name=COLLECTION_NAME, requests=search_queries)
    predicted = [r[0].payload["label"] for r in res]
    
    return predicted

batch_size = 128
batch = []

correct, incorrect = 0, 0
for image in tqdm(val_images):
    label = os.path.basename(os.path.dirname(image))

    if len(batch) >= batch_size:
        for b, predicted in zip(batch, search_data(batch)):
            image, label = b[0], b[1]
            if label == predicted:
                correct += 1
            else:
                incorrect += 1

        batch = []

    batch.append([image, label])

if len(batch):
    for b, predicted in zip(batch, search_data(batch)):
        image, label = b[0], b[1]
        if label == predicted:
            correct += 1
        else:
            incorrect += 1

    batch = []
    
acc = correct / (correct + incorrect)
print(f"Accuracy: {acc}")

  client.recreate_collection(
/tmp/ipykernel_143896/1035688666.py:18: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  ).dict()["result"]["points_count"] or 0
100%|██████████| 9977/9977 [00:46<00:00, 215.15it/s]
100%|██████████| 1811/1811 [00:08<00:00, 208.00it/s]

Accuracy: 0.864715626725566



