In [1]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

In [2]:
download_path = "datasets"  # change to your own download path
path_query = download_path + "/query_img_4186"
# Directory holding the bounding-box text files of the instance(s) in the
# query images.
path_query_txt = download_path + "/query_img_box_4186"
path_gallery = download_path + "/gallery_4186"

# glob returns matches in arbitrary, filesystem-dependent order. Sort every
# listing so that (a) re-running the notebook gives the same row/column order
# and (b) name_query[i] and name_box[i] line up by position.
# NOTE(review): (b) assumes each box file shares its stem with its query image
# (e.g. 27.jpg <-> 27.txt) -- confirm against the dataset.
name_query = sorted(glob.glob(path_query + "/*.jpg"))
num_query = len(name_query)

name_box = sorted(glob.glob(path_query_txt + "/*.txt"))

name_gallery = sorted(glob.glob(path_gallery + "/*.jpg"))
num_gallery = len(name_gallery)
record_all = np.zeros((num_query, num_gallery))

# File stems (name without directory or extension), derived from the very same
# lists that are fed to the datasets so the indices can never drift apart.
query_imgs_no = [os.path.splitext(os.path.basename(p))[0] for p in name_query]
gallery_imgs_no = [os.path.splitext(os.path.basename(p))[0] for p in name_gallery]

In [3]:
# CNN encoder: a ResNet50 from the local `resnet` module, restored from a
# checkpoint on disk and used as a feature extractor.
from transformers import AutoImageProcessor, AutoModel
import torch.nn as nn
from resnet import ResNet50


PATH = "./resnet50.pth"

# NOTE(review): the argument 10 is presumably the class count the checkpoint
# was trained with -- confirm against resnet.py.
resnet = ResNet50(10)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
cnn_state_dict = torch.load(PATH, map_location=torch.device('cpu'))
resnet.load_state_dict(cnn_state_dict)

# An empty Sequential passes its input through unchanged, so the network now
# outputs whatever is fed into the former `fc` layer.
resnet.fc = nn.Sequential()
resnet.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (max_pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsampling): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, m

In [4]:
import timm
import torch

# Vision-transformer encoder (DINOv2 ViT-S/14 weights fetched through timm),
# switched to eval mode straight away.
# NOTE(review): num_classes=0 is understood to drop the classifier head so the
# model returns features -- confirm in the timm documentation.
transformer = timm.create_model(
    'vit_small_patch14_dinov2.lvd142m',
    pretrained=True,
    num_classes=0,
).eval()

In [5]:
from torchvision.transforms import Resize, Compose, ToTensor
from torch.utils.data import Dataset, DataLoader


class QueryDataset(torch.utils.data.Dataset):
    """Dataset of query images, optionally cropped to their annotated box.

    Args:
        image_paths: Sequence of query image file paths.
        bounding_box_path: Sequence of bounding-box text file paths. Entry
            ``i`` belongs to ``image_paths[i]``, so both sequences must be in
            the same order.
        transform: Callable applied to the PIL image; its result is what
            ``__getitem__`` returns.
        crop: If True, crop every image to its bounding box before applying
            ``transform``. Defaults to False, which keeps the previous
            behaviour (the crop was disabled) and skips reading the box file.
    """

    def __init__(self, image_paths, bounding_box_path, transform, crop=False):
        self.image_paths = image_paths
        self.transform = transform
        self.bounding_box_path = bounding_box_path
        self.crop = crop

    def __len__(self):
        """Return the number of query images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, optionally crop it, and transform it."""
        # Image.open is lazy and keeps the file open; converting inside the
        # `with` block loads the pixels and releases the handle. Forcing RGB
        # gives every sample three channels, so grayscale/RGBA files can be
        # batched together with the rest.
        with Image.open(self.image_paths[idx]) as opened:
            image = opened.convert("RGB")

        if self.crop:
            # The box file is read as four numbers x, y, w, h and turned into
            # PIL's (left, upper, right, lower) box.
            # NOTE(review): a file listing several boxes will fail to unpack
            # here -- confirm one box per query before enabling `crop`.
            x, y, w, h = np.loadtxt(self.bounding_box_path[idx])
            image = image.crop((x, y, x + w, y + h))

        return self.transform(image)

class GalleryDataset(torch.utils.data.Dataset):
    """Dataset of gallery images.

    Args:
        image_paths: Sequence of gallery image file paths.
        transform: Callable applied to the PIL image; its result is what
            ``__getitem__`` returns.
    """

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of gallery images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it transformed."""
        # Image.open is lazy and keeps the file open; converting inside the
        # `with` block loads the pixels and releases the handle. Forcing RGB
        # gives every sample three channels, so grayscale/RGBA files can be
        # batched together with the rest.
        with Image.open(self.image_paths[idx]) as opened:
            image = opened.convert("RGB")
        return self.transform(image)



In [7]:
import torchvision.transforms as transforms
import torch

# Preprocessing for the CNN branch: fixed 256x256 resize, then tensor
# conversion.
cnn_pipeline_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
cnn_transforms = transforms.Compose(cnn_pipeline_steps)

# Preprocessing for the transformer branch: built by timm from the data
# configuration it resolves for the loaded model, in inference mode.
data_config = timm.data.resolve_model_data_config(transformer)
transformer_transforms = timm.data.create_transform(
    is_training=False,
    **data_config,
)


In [10]:
# Run every query image through the CNN, one batch of 16 at a time.
cnn_query_dataset = QueryDataset(name_query, name_box, transform=cnn_transforms)
cnn_query_dataloader = DataLoader(cnn_query_dataset, batch_size=16, num_workers=4)

# Inference only, so gradient tracking is switched off. The result is a list
# with one output tensor per batch, in loader order.
with torch.no_grad():
    cnn_query_embeddings = [resnet(batch) for batch in cnn_query_dataloader]
    

In [11]:
# Run every gallery image through the CNN, one batch of 16 at a time.
# The gallery must use the same preprocessing as the CNN query pass
# (cnn_transforms); feeding it the transformer's preprocessing would make the
# two sets of CNN embeddings incomparable.
cnn_gallery_dataset = GalleryDataset(name_gallery, transform=cnn_transforms)
cnn_gallery_dataloader = DataLoader(cnn_gallery_dataset, batch_size=16, num_workers=4)

# Inference only, so gradient tracking is switched off. One output tensor is
# collected per batch, in loader order.
cnn_gallery_embeddings = []
with torch.no_grad():
    for images in cnn_gallery_dataloader:
        outputs = resnet(images)
        cnn_gallery_embeddings.append(outputs)

In [None]:
# Run every query image through the transformer, one batch of 16 at a time.
# `transformer_transforms` is the timm preprocessing pipeline; the bare name
# `transforms` is the torchvision module and cannot be called on an image.
tformer_query_dataset = QueryDataset(name_query, name_box, transform=transformer_transforms)
tformer_query_dataloader = DataLoader(tformer_query_dataset, batch_size=16, num_workers=4)

tformer_query_embeddings = []
with torch.no_grad():
    for images in tformer_query_dataloader:
        # Call the model itself rather than forward_features: the latter
        # returns the unpooled per-token features, whereas the later
        # normalisation/similarity cells need one vector per image.
        outputs = transformer(images)
        tformer_query_embeddings.append(outputs)
        

In [None]:
# Run every gallery image through the transformer, one batch of 16 at a time.
# `transformer_transforms` is the timm preprocessing pipeline; the bare name
# `transforms` is the torchvision module and cannot be called on an image.
tformer_gallery_dataset = GalleryDataset(name_gallery, transform=transformer_transforms)
tformer_gallery_dataloader = DataLoader(tformer_gallery_dataset, batch_size=16, num_workers=3)

tformer_gallery_embeddings = []
with torch.no_grad():
    for images in tformer_gallery_dataloader:
        # Call the model itself rather than forward_features: the latter
        # returns the unpooled per-token features, whereas the later
        # normalisation/similarity cells need one vector per image.
        outputs = transformer(images)
        tformer_gallery_embeddings.append(outputs)

In [13]:
# Stitch the per-batch CNN outputs back together along the batch axis
# (torch.cat concatenates along dim 0 by default).
cnn_query_embeddings_combined = torch.cat(cnn_query_embeddings)
cnn_gallery_embeddings_combined = torch.cat(cnn_gallery_embeddings)

In [14]:
# Collapse everything after the batch axis so each image is one flat row.
cnn_query_embeddings_combined = cnn_query_embeddings_combined.view(
    cnn_query_embeddings_combined.size(0), -1
)
cnn_gallery_embeddings_combined = cnn_gallery_embeddings_combined.view(
    cnn_gallery_embeddings_combined.size(0), -1
)

In [19]:
# Sanity check: show the (rows, features) shape of the query matrix, then of
# the gallery matrix.
for combined in (cnn_query_embeddings_combined, cnn_gallery_embeddings_combined):
    print(combined.shape)

torch.Size([20, 2048])
torch.Size([4967, 2048])


In [None]:
# Stitch the per-batch transformer outputs back together along the batch axis
# (torch.cat concatenates along dim 0 by default).
tformer_query_embeddings_combined = torch.cat(tformer_query_embeddings)
tformer_gallery_embeddings_combined = torch.cat(tformer_gallery_embeddings)

In [None]:
# L2-normalise every embedding row. `F` is torch.nn.functional, imported in
# the notebook's first cell. dim=1 (the default) is spelled out so it is clear
# that the normalisation runs over the feature axis.
# NOTE(review): assumes both matrices are (num_images, feature_dim) -- confirm.
tformer_query_embeddings_norm = F.normalize(tformer_query_embeddings_combined, dim=1)
tformer_gallery_embeddings_norm = F.normalize(tformer_gallery_embeddings_combined, dim=1)

In [None]:
# Cosine similarity between every query and every gallery embedding, giving a
# (num_query, num_gallery) matrix.
# Both inputs hold L2-normalised rows, so the plain matrix product already is
# the cosine similarity; no per-query loop or second division by the norms is
# needed. (The previous loop multiplied the whole query matrix on every
# iteration, which produced a 2-D result that could not be stored in row i.)
tformer_cosine_similarities = torch.matmul(
    tformer_query_embeddings_norm,
    tformer_gallery_embeddings_norm.T,
)

In [None]:
# Cosine similarity for the CNN branch, computed here because no other cell
# defines it: L2-normalise the rows of both CNN embedding matrices and take
# the matrix product, giving a (num_query, num_gallery) matrix.
cnn_cosine_similarities = torch.matmul(
    F.normalize(cnn_query_embeddings_combined, dim=1),
    F.normalize(cnn_gallery_embeddings_combined, dim=1).T,
)

# Late fusion: weighted average of the two similarity matrices. The two
# weights always sum to 1.
cnn_weight = 0.4
tformer_weight = 1 - cnn_weight
weighted_similarities = tformer_weight*tformer_cosine_similarities + cnn_weight*cnn_cosine_similarities

In [None]:
# Rank the gallery for every query: each row is ordered from most to least
# similar, and sorted_indices holds the matching gallery column positions.
sorted_similarities, sorted_indices = weighted_similarities.sort(dim=1, descending=True)

In [None]:
# Write one line per query: "Q<k>: " followed by that query's ranked gallery
# indices separated by spaces. The `with` block guarantees the file is flushed
# and closed even if writing fails part-way.
# NOTE(review): the values written are positions in the gallery file list, not
# gallery file names -- confirm this is the format the consumer expects.
with open("ensemble_rank_list.txt", "w") as rank_file:
    for query_number, ranked_row in enumerate(sorted_indices, start=1):
        ranked_text = " ".join(str(gallery_index) for gallery_index in ranked_row.tolist())
        rank_file.write(f"Q{query_number}: {ranked_text}\n")

In [None]:
import os
import glob
import numpy as np
import shutil


TOP_K = 10  # how many best-ranked gallery images are copied per query

download_path = "datasets"
path_gallery = os.path.join(download_path, "gallery_4186")

# For each query, create datasets/image_<query stem>/ and copy its TOP_K
# best-ranked gallery images into it.
# Row `query_idx` of sorted_indices was computed from name_query[query_idx] and
# its column values index into name_gallery, so both lists are used directly.
# This replaces a per-query list.index() search and a separately globbed
# gallery list, and drops a leftover line that depended on a stale loop
# variable `i` from an earlier cell.
for query_idx, query_path in enumerate(name_query):
    query_img_no = os.path.splitext(os.path.basename(query_path))[0]

    # Create a folder for the query image
    query_img_folder = os.path.join(download_path, f"image_{query_img_no}")
    os.makedirs(query_img_folder, exist_ok=True)

    # Gallery positions of the TOP_K most similar images for this query
    top_k_indices = sorted_indices[query_idx, :TOP_K].tolist()

    # Copy them into the query image's folder, keeping the original file name
    for index in top_k_indices:
        src_path = name_gallery[index]
        dst_path = os.path.join(query_img_folder, os.path.basename(src_path))
        shutil.copy(src_path, dst_path)
