In [None]:
from src.utils.colab_setup import setup_environment
setup_environment()

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.metrics import recall_score
import os

from models.retrieval_model import RetrievalModel
from dataloaders.deepfashion_retrievals import DeepFashionRetrieval


# Locations of the retrieval images, the evaluation partition file and the
# trained checkpoint.
img_dir = "/content/data/retrieval/img"
anno_path = "/content/data/retrieval/list_eval_partition.txt"
model_path = "/content/fashion-multitask-model/retrieval_model_best.pth"

# Preprocessing shared by both splits: resize to 256x256, then convert the
# image to a tensor.
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)


# Build the query split first and the gallery split second; both read the
# same image directory and annotation file and differ only in ``split``.
query_dataset, gallery_dataset = (
    DeepFashionRetrieval(
        img_dir=img_dir,
        anno_path=anno_path,
        split=split_name,
        transform=transform,
    )
    for split_name in ("query", "gallery")
)


# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate the network on that device and restore the saved weights.
# NOTE(review): torch.load unpickles the checkpoint file -- only load
# checkpoints that come from a trusted source.
model = RetrievalModel().to(device)
model.load_state_dict(
    torch.load(model_path, map_location=device)
)
# Switch to evaluation mode before any embeddings are computed.
model.eval()


def get_embeddings(dataset):
    """Run the module-level ``model`` over ``dataset`` and collect its outputs.

    The dataset is read in order (no shuffling) in batches of 64, so the
    i-th embedding, ID and path all belong to the i-th dataset sample.

    Args:
        dataset: Dataset whose batches unpack into ``(images, ids, paths)``.

    Returns:
        A 3-tuple ``(embeddings, item_ids, image_paths)``: the per-batch model
        outputs concatenated into one CPU tensor, plus flat lists of the IDs
        and paths in the same order.
    """
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    batch_outputs, collected_ids, collected_paths = [], [], []

    # Inference only: gradients are not needed.
    with torch.no_grad():
        for batch_imgs, batch_ids, batch_paths in loader:
            batch_embed = model(batch_imgs.to(device))
            # Keep the results on the CPU so they can be concatenated later.
            batch_outputs.append(batch_embed.cpu())
            collected_ids.extend(batch_ids)
            collected_paths.extend(batch_paths)

    all_embeddings = torch.cat(batch_outputs)
    return all_embeddings, collected_ids, collected_paths


# Embed both splits once up front; the recall computation and the retrieval
# visualisation reuse these tensors and the matching ID / path lists.
print("Generating embeddings...")
query_embeds, query_ids, query_paths = get_embeddings(query_dataset)
gallery_embeds, gallery_ids, gallery_paths = get_embeddings(gallery_dataset)


def compute_recall_at_k(query_embeds, query_ids, gallery_embeds, gallery_ids, k=1):
    """Return the fraction of queries with a matching ID in their top-k results.

    Each query embedding is compared to every gallery embedding by cosine
    similarity. A query counts as a hit when its ID equals the ID of at least
    one of the ``k`` most similar gallery entries.

    Args:
        query_embeds: Tensor of query embeddings, iterated row by row.
        query_ids: Sequence of IDs aligned with ``query_embeds``.
        gallery_embeds: Tensor of gallery embeddings, one row per entry.
        gallery_ids: Sequence of IDs aligned with ``gallery_embeds``.
        k: Number of top matches to inspect. Values larger than the gallery
            are clamped to the gallery size instead of raising.

    Returns:
        Recall@k as a float in ``[0, 1]``. Returns ``0.0`` when there are no
        queries or nothing can be retrieved (empty gallery or ``k == 0``).
    """
    total = len(query_ids)
    # torch.topk raises when k exceeds the number of candidates, so cap it.
    effective_k = min(k, gallery_embeds.shape[0])
    if total == 0 or effective_k == 0:
        # Avoids a division by zero and a top-k over nothing: no hits.
        return 0.0

    correct = 0
    for i, query_vec in enumerate(query_embeds):
        sims = F.cosine_similarity(query_vec.unsqueeze(0), gallery_embeds)
        # Plain Python ints make the list lookups below unambiguous.
        topk_indices = torch.topk(sims, k=effective_k).indices.tolist()
        topk_ids = [gallery_ids[j] for j in topk_indices]

        if query_ids[i] in topk_ids:
            correct += 1

    return correct / total

# Evaluate retrieval quality at three cut-offs (top-1, top-5 and top-10),
# then report each value as a percentage.
recall_1, recall_5, recall_10 = (
    compute_recall_at_k(
        query_embeds, query_ids, gallery_embeds, gallery_ids, k=cutoff
    )
    for cutoff in (1, 5, 10)
)

print(f"\n Recall@1:  {recall_1*100:.2f}%")
print(f" Recall@5:  {recall_5*100:.2f}%")
print(f" Recall@10: {recall_10*100:.2f}%")

In [None]:
def show_top_k(query_index, k=5):
    """Plot one query image next to its ``k`` most similar gallery images.

    Gallery entries are ranked by cosine similarity between the query's
    embedding and the module-level ``gallery_embeds``. Images are read from
    ``img_dir`` using the paths in ``query_paths`` / ``gallery_paths``.

    Args:
        query_index: Position of the query in ``query_embeds`` and
            ``query_paths``.
        k: Number of matches to display. Values larger than the gallery are
            clamped to the gallery size instead of raising.
    """
    query_vec = query_embeds[query_index]
    sims = F.cosine_similarity(query_vec.unsqueeze(0), gallery_embeds)
    # torch.topk raises when k exceeds the number of candidates, so cap it.
    k = min(k, sims.numel())
    # Plain Python ints make the path lookups below unambiguous.
    topk_indices = torch.topk(sims, k=k).indices.tolist()

    # squeeze=False always returns a 2-D array of axes; without it a single
    # column (k == 0) yields a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(1, k + 1, figsize=(15, 4), squeeze=False)
    axs = axs[0]

    # Query image
    query_img = plt.imread(os.path.join(img_dir, query_paths[query_index]))
    axs[0].imshow(query_img)
    axs[0].set_title("Query")
    axs[0].axis("off")

    # Top-k matches, most similar first
    for i, idx in enumerate(topk_indices):
        img_path = os.path.join(img_dir, gallery_paths[idx])
        img = plt.imread(img_path)
        axs[i + 1].imshow(img)
        axs[i + 1].set_title(f"Match {i+1}")
        axs[i + 1].axis("off")

    plt.tight_layout()
    plt.show()

import random

# Display the top-5 retrievals for three randomly picked queries.
for _ in range(3):
    idx = random.randrange(len(query_dataset))
    print(f"\n Showing retrieval for Query #{idx} (Item ID: {query_ids[idx]})")
    show_top_k(idx, k=5)