In [1]:
import glob
import os

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch

In [2]:
download_path = "datasets"  # change to your own download path
path_query = download_path + "/query_img_4186"
# path_query_txt is the directory holding the bounding-box .txt file of the instance(s)
# in each query image (one file per query image, same file name as the image).
path_query_txt = download_path + "/query_img_box_4186"
path_gallery = download_path + "/gallery_4186"


def file_stem(path):
    """Return the file name of `path` without directory or extension ('a/b/27.jpg' -> '27')."""
    return os.path.splitext(os.path.basename(path))[0]


def image_sort_key(path):
    """Sort key ordering purely numeric file names numerically ('2.jpg' before '10.jpg').

    Non-numeric names sort after numeric ones, alphabetically.
    """
    stem = file_stem(path)
    return (0, int(stem), "") if stem.isdecimal() else (1, 0, stem)


# glob() returns files in arbitrary, filesystem-dependent order. Sorting every listing
# guarantees that (a) query image i and bounding-box file i describe the same query
# (QueryDataset pairs them by index) and (b) embedding / rank indices are reproducible.
name_query = sorted(glob.glob(path_query + "/*.jpg"), key=image_sort_key)
num_query = len(name_query)

name_box = sorted(glob.glob(path_query_txt + "/*.txt"), key=image_sort_key)
assert [file_stem(p) for p in name_box] == [file_stem(p) for p in name_query], \
    "each query image needs exactly one bounding-box .txt file with the same name"

name_gallery = sorted(glob.glob(path_gallery + "/*.jpg"), key=image_sort_key)
num_gallery = len(name_gallery)
record_all = np.zeros((num_query, num_gallery))

# Image ids (file stems), aligned index-for-index with name_query / name_gallery.
query_imgs_no = [file_stem(p) for p in name_query]
gallery_imgs_no = [file_stem(p) for p in name_gallery]

In [None]:
# CNN feature extractor: ResNet-50 with its `fc` head replaced by an identity, so the
# forward pass returns the features fed into that head rather than class scores
# (assuming `fc` is the last layer applied in ResNet50.forward - confirm in resnet.py).
import torch.nn as nn
from resnet import ResNet50


PATH = "./resnet50.pth"
# NOTE(review): 10 is presumably the class count the checkpoint was trained with - confirm.
resnet = ResNet50(10)
# torch.load unpickles the file, which can execute arbitrary code: only load trusted checkpoints.
resnet.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
# Assigning a module to an attribute that does not exist would silently register a new,
# unused layer and leave the classifier in place, so check the head really is `fc`.
assert hasattr(resnet, "fc"), "ResNet50 has no `fc` head to strip"
resnet.fc = nn.Identity()
resnet.eval()


In [None]:
from torchvision.transforms import Resize, Compose, ToTensor
from torch.utils.data import Dataset, DataLoader


class QueryDataset(torch.utils.data.Dataset):
    """Query images for feature extraction, optionally cropped to their bounding boxes.

    Args:
        image_paths: sequence of query image file paths.
        bounding_box_path: sequence of bounding-box ``.txt`` paths aligned index-for-index
            with ``image_paths``. Each file holds ``x y w h`` values (read with
            ``np.loadtxt``); several rows, one per instance, are also accepted. The files
            are only read when ``crop`` is True.
        transform: callable applied to the RGB ``PIL.Image``; its result is what
            ``__getitem__`` returns.
        crop: when True, crop each image to its box (to the union of its boxes if the file
            holds several) before ``transform``. Defaults to False: whole images, which is
            what this notebook used (the crop was previously commented out).

    Raises:
        ValueError: if ``image_paths`` and ``bounding_box_path`` differ in length.
    """

    def __init__(self, image_paths, bounding_box_path, transform, crop=False):
        # Images and boxes are paired purely by index, so a length mismatch means misalignment.
        if len(image_paths) != len(bounding_box_path):
            raise ValueError(
                f"got {len(image_paths)} query images but {len(bounding_box_path)} bounding-box files"
            )
        self.image_paths = image_paths
        self.transform = transform
        self.bounding_box_path = bounding_box_path
        self.crop = crop

    def __len__(self):
        return len(self.image_paths)

    def _load_box(self, idx):
        """Return the (left, upper, right, lower) crop box for item ``idx``.

        Each row of the box file is ``x, y, w, h``; multiple rows are merged into the
        smallest box enclosing all of them.

        Raises:
            ValueError: if the file does not contain rows of exactly 4 numbers.
        """
        boxes = np.atleast_2d(np.loadtxt(self.bounding_box_path[idx]))
        if boxes.shape[-1] != 4:
            raise ValueError(f"expected 'x y w h' rows in {self.bounding_box_path[idx]}")
        x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        return (float(x.min()), float(y.min()), float((x + w).max()), float((y + h).max()))

    def __getitem__(self, idx):
        # Convert to 3-channel RGB so grayscale / RGBA / palette files yield items with the
        # same channel count as RGB ones (needed for batching). The context manager closes
        # the file handle PIL opened lazily.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.crop:
            image = image.crop(self._load_box(idx))
        return self.transform(image)

class GalleryDataset(torch.utils.data.Dataset):
    """Gallery images for feature extraction.

    Args:
        image_paths: sequence of gallery image file paths; item ``i`` comes from
            ``image_paths[i]``.
        transform: callable applied to the RGB ``PIL.Image``; its result is what
            ``__getitem__`` returns.
    """

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to 3-channel RGB so grayscale / RGBA / palette files yield items with the
        # same channel count as RGB ones (needed for batching). The context manager closes
        # the file handle PIL opened lazily.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        return self.transform(image)



In [None]:
# Embed every query image with the frozen CNN. The loader keeps the default shuffle=False,
# so the concatenated embeddings stay in name_query order.
# NOTE(review): `cnn_transforms` is not defined in any visible cell - define it earlier.
cnn_query_dataset = QueryDataset(name_query, name_box, transform=cnn_transforms)
cnn_query_dataloader = DataLoader(cnn_query_dataset, batch_size=16, num_workers=4)

with torch.no_grad():  # inference only: no autograd graph needed
    cnn_query_embeddings = [resnet(batch) for batch in cnn_query_dataloader]
    

In [None]:
# Embed every gallery image with the same frozen CNN. Gallery images must be preprocessed
# exactly like the query images (same transform as the query cell); otherwise query and
# gallery features come from different input pipelines and their cosine similarities are
# not comparable. The loader keeps shuffle=False, so rows follow name_gallery order.
# NOTE(review): `cnn_transforms` is not defined in any visible cell - define it earlier.
cnn_gallery_dataset = GalleryDataset(name_gallery, transform=cnn_transforms)
cnn_gallery_dataloader = DataLoader(cnn_gallery_dataset, batch_size=16, num_workers=4)

cnn_gallery_embeddings = []
with torch.no_grad():  # inference only: no autograd graph needed
    for images in cnn_gallery_dataloader:
        cnn_gallery_embeddings.append(resnet(images))

In [None]:
# Join the per-batch feature tensors along the batch dimension: one tensor per split.
cnn_query_embeddings_combined, cnn_gallery_embeddings_combined = (
    torch.cat(batches, dim=0) for batches in (cnn_query_embeddings, cnn_gallery_embeddings)
)

In [None]:
def _as_rows(features):
    """View `features` as (num_images, feature_dim), flattening any trailing dimensions."""
    return features.view(features.size(0), -1)


cnn_query_embeddings_combined = _as_rows(cnn_query_embeddings_combined)
cnn_gallery_embeddings_combined = _as_rows(cnn_gallery_embeddings_combined)

In [None]:
import torch.nn.functional as F

# L2-normalise every row (one image's feature vector); cosine similarity between unit
# vectors is then just their dot product.
cnn_query_embeddings_norm = F.normalize(cnn_query_embeddings_combined)
cnn_gallery_embeddings_norm = F.normalize(cnn_gallery_embeddings_combined)

# (num_query, D) @ (D, num_gallery) -> (num_query, num_gallery) cosine-similarity matrix.
# Same result as F.cosine_similarity on broadcast (Q, 1, D) x (1, G, D) inputs (up to
# float rounding), but without materialising a Q x G x D intermediate tensor.
cnn_cosine_similarities = cnn_query_embeddings_norm @ cnn_gallery_embeddings_norm.t()

In [None]:
# Rank the gallery for every query (one row per query): column 0 holds the index of the
# most similar gallery image, and sorted_similarities holds the matching scores.
sorted_similarities, sorted_indices = cnn_cosine_similarities.sort(dim=1, descending=True)

In [None]:
# Write the full ranking for each query as "Q<n>: id id id ...", most similar first.
# Ranked positions index the rows of the gallery embedding matrix, i.e. name_gallery
# (the gallery loader does not shuffle). Map them to gallery file stems, because a
# position in a glob listing is meaningless outside this run.
cnn_gallery_ids = [os.path.splitext(os.path.basename(p))[0] for p in name_gallery]

with open("cnn_rank_list.txt", "w") as f:  # closed even if writing fails
    for i, ranking in enumerate(sorted_indices.tolist()):
        f.write("Q" + str(i + 1) + ": ")
        f.write(" ".join(cnn_gallery_ids[j] for j in ranking))
        f.write("\n")

In [None]:
# Write the TOP_K best gallery matches per query as "Q<n>: id id ...", most similar first.
# Ranked positions index the rows of the gallery embedding matrix, i.e. name_gallery
# (the gallery loader does not shuffle); they are mapped to gallery file stems.
TOP_K = 10  # number of gallery images listed per query

cnn_gallery_ids = [os.path.splitext(os.path.basename(p))[0] for p in name_gallery]

with open("cnn_top10.txt", "w") as f:  # closed even if writing fails
    for i, top_k in enumerate(sorted_indices[:, :TOP_K].tolist()):
        f.write("Q" + str(i + 1) + ": ")
        f.write(" ".join(cnn_gallery_ids[j] for j in top_k))
        f.write("\n")