In [None]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

In [54]:
# Dataset locations -- change `download_path` to your own download path.
download_path = "datasets"
path_query = download_path + "/query_img_4186"
# Directory holding the bounding-box .txt file(s) of the instance(s) in each query image.
path_query_txt = download_path + "/query_img_box_4186"
path_gallery = download_path + "/gallery_4186"

# glob() returns files in arbitrary, filesystem-dependent order. Sort every
# listing so that (a) query image i is paired with box file i, and (b) the row /
# column indices of every similarity matrix -- including matrices computed
# elsewhere and loaded from disk -- refer to the same files on every run.
# NOTE(review): pairing by sorted order assumes each box file shares its query
# image's file stem (e.g. 27.jpg <-> 27.txt) -- confirm.
name_query = sorted(glob.glob(path_query + "/*.jpg"))
num_query = len(name_query)

name_box = sorted(glob.glob(path_query_txt + "/*.txt"))

name_gallery = sorted(glob.glob(path_gallery + "/*.jpg"))
num_gallery = len(name_gallery)
record_all = np.zeros((num_query, num_gallery))

# Image ids (file stems), derived from the SAME lists that feed the datasets so
# that query_imgs_no[i] / gallery_imgs_no[j] always label row i / column j.
# os.path also handles Windows separators, unlike splitting on "/".
query_imgs_no = [os.path.splitext(os.path.basename(p))[0] for p in name_query]
gallery_imgs_no = [os.path.splitext(os.path.basename(p))[0] for p in name_gallery]

In [None]:
# CNN encoder: ResNet-50 restored from a local checkpoint, used as a feature extractor.
# NOTE(review): AutoImageProcessor / AutoModel are imported but not used in this cell.
from transformers import AutoImageProcessor, AutoModel
import torch.nn as nn
from resnet import ResNet50


PATH = "./resnet50.pth"
# The argument 10 presumably sizes the checkpoint's classification head
# (10 classes) -- TODO confirm against resnet.py.
resnet = ResNet50(10)
# Load onto the CPU regardless of the device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
resnet.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
# Replace the `fc` head with an explicit pass-through so the model returns the
# features that would have fed the classifier (assuming ResNet50.forward ends
# with self.fc). nn.Identity is the idiomatic no-op module; an empty
# nn.Sequential behaves the same but hides the intent.
resnet.fc = nn.Identity()

# Inference behaviour for layers such as batch-norm / dropout.
resnet.eval()


In [None]:
import timm
import torch


# Second feature extractor: the ViT-L encoder named 'samvit_large_patch16.sa1b'
# in timm (SAM image encoder per the model name). pretrained=True fetches the
# weights; num_classes=0 asks timm for a model without a classification head.
# Module.eval() returns the module itself, so it can be chained here.
sam = timm.create_model(
    'samvit_large_patch16.sa1b',
    num_classes=0,
    pretrained=True,
).eval()

In [None]:
from torchvision.transforms import Resize, Compose, ToTensor
from torch.utils.data import Dataset, DataLoader


def _load_rgb(image_path):
    """Open ``image_path`` and return it as a fully loaded 3-channel RGB PIL image.

    ``convert("RGB")`` normalises non-RGB JPEGs (grayscale, CMYK, ...). Those would
    otherwise produce tensors with the wrong channel count for the 3-channel
    ``Normalize`` and for batch collation. The ``with`` block closes the file
    handle that ``Image.open`` keeps open for lazy decoding. ``convert`` has
    already read the pixels, so the returned image remains usable.
    """
    with Image.open(image_path) as img:
        return img.convert("RGB")


class QueryDataset(torch.utils.data.Dataset):
    """Query images, optionally cropped to their annotated bounding box(es).

    Args:
        image_paths: sequence of query image paths.
        bounding_box_path: sequence of box ``.txt`` paths, index-aligned with
            ``image_paths``.
        transform: callable applied to each PIL image; its result is returned.
        crop: when True, crop each image to the rectangle enclosing all of its
            boxes before ``transform``. Defaults to False (whole-image queries).
    """

    def __init__(self, image_paths, bounding_box_path, transform, crop=False):
        self.image_paths = image_paths
        self.transform = transform
        self.bounding_box_path = bounding_box_path
        self.crop = crop

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _union_box(box_path):
        """Return the (left, top, right, bottom) rectangle enclosing every box in ``box_path``.

        Each row of the text file is read as ``x y w h``. ``ndmin=2`` makes a file
        with one box and a file with several boxes both load as a 2-D array, so
        multi-instance queries do not hit an unpacking error.
        """
        boxes = np.loadtxt(box_path, ndmin=2)
        left = boxes[:, 0].min()
        top = boxes[:, 1].min()
        right = (boxes[:, 0] + boxes[:, 2]).max()
        bottom = (boxes[:, 1] + boxes[:, 3]).max()
        return float(left), float(top), float(right), float(bottom)

    def __getitem__(self, idx):
        image = _load_rgb(self.image_paths[idx])
        # The box file is read only when it is actually used. Unconditionally
        # unpacking it into x, y, w, h wasted I/O and raised on files listing
        # several instances.
        if self.crop:
            image = image.crop(self._union_box(self.bounding_box_path[idx]))
        return self.transform(image)


class GalleryDataset(torch.utils.data.Dataset):
    """Gallery images loaded as RGB and passed through ``transform``.

    Args:
        image_paths: sequence of gallery image paths.
        transform: callable applied to each PIL image; its result is returned.
    """

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        return self.transform(_load_rgb(self.image_paths[idx]))



In [None]:
import torchvision.transforms as transforms
import torch

# ResNet preprocessing: fixed 256x256 resize (aspect ratio is not preserved),
# conversion to a float tensor, then per-channel normalisation with the usual
# ImageNet mean/std.
# NOTE(review): assumes resnet50.pth was trained with these statistics -- confirm.
cnn_transforms = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

# SAM preprocessing: timm resolves the model's own input config (e.g. input size
# and normalisation) and builds the matching evaluation-time transform.
data_config = timm.data.resolve_model_data_config(sam)
sam_transforms = timm.data.create_transform(is_training=False, **data_config)


In [None]:
cnn_query_dataset = QueryDataset(name_query, name_box, transform=cnn_transforms)
cnn_query_dataloader = DataLoader(cnn_query_dataset, batch_size=20, num_workers=32)

# ResNet features for every query image, one tensor per DataLoader batch, in
# dataset order (no shuffle argument is passed, so DataLoader keeps order).
with torch.no_grad():
    cnn_query_embeddings = [resnet(batch) for batch in cnn_query_dataloader]
    

In [None]:
cnn_gallery_dataset = GalleryDataset(name_gallery, transform=cnn_transforms)
cnn_gallery_dataloader = DataLoader(cnn_gallery_dataset, batch_size=2048, num_workers=32)

# ResNet features for the whole gallery, one tensor per batch. Appending inside
# the loop keeps already-computed batches if the loop is interrupted.
cnn_gallery_embeddings = []
with torch.no_grad():
    for batch in cnn_gallery_dataloader:
        features = resnet(batch)
        cnn_gallery_embeddings.append(features)

In [None]:
tformer_query_dataset = QueryDataset(name_query, name_box, transform=sam_transforms)
tformer_query_dataloader = DataLoader(tformer_query_dataset, batch_size=20, num_workers=32)

# SAM encoder outputs from timm's forward_features (i.e. without the head) for
# every query image, one tensor per batch, in dataset order.
with torch.no_grad():
    tformer_query_embeddings = [sam.forward_features(batch) for batch in tformer_query_dataloader]
        

In [None]:
from tqdm import tqdm
from tqdm.notebook import tqdm

In [None]:
tformer_gallery_dataset = GalleryDataset(name_gallery, transform=sam_transforms)
tformer_gallery_dataloader = DataLoader(tformer_gallery_dataset, batch_size=512, num_workers=32)

# SAM encoder outputs for the whole gallery, with a tqdm progress bar. Appending
# inside the loop keeps already-computed batches if the loop is interrupted.
tformer_gallery_embeddings = []
with torch.no_grad():
    for batch in tqdm(tformer_gallery_dataloader, desc="progress"):
        features = sam.forward_features(batch)
        tformer_gallery_embeddings.append(features)

In [None]:
# Persist the per-batch SAM gallery features (a list of tensors), presumably so
# they can be reloaded instead of recomputed.
torch.save(obj=tformer_gallery_embeddings, f="gallery_embeddings_SAM.pt")

In [None]:
# Merge the per-batch ResNet outputs along the batch dimension (torch.cat defaults to dim=0).
cnn_query_embeddings_combined = torch.cat(cnn_query_embeddings)
cnn_gallery_embeddings_combined = torch.cat(cnn_gallery_embeddings)

In [None]:
# Collapse any trailing feature dimensions so each image becomes one row vector.
cnn_query_embeddings_combined = cnn_query_embeddings_combined.view(len(cnn_query_embeddings_combined), -1)
cnn_gallery_embeddings_combined = cnn_gallery_embeddings_combined.view(len(cnn_gallery_embeddings_combined), -1)

In [None]:
# Sanity check: (num_images, feature_dim) for query and gallery.
for embeddings in (cnn_query_embeddings_combined, cnn_gallery_embeddings_combined):
    print(embeddings.shape)

In [None]:
# Merge the per-batch SAM outputs along the batch dimension (torch.cat defaults to dim=0).
tformer_query_embeddings_combined = torch.cat(tformer_query_embeddings)
tformer_gallery_embeddings_combined = torch.cat(tformer_gallery_embeddings)

In [None]:
# Inspect the raw SAM output shapes before flattening.
for embeddings in (tformer_query_embeddings_combined, tformer_gallery_embeddings_combined):
    print(embeddings.shape)

In [None]:

# Flatten every SAM output to a single row vector per image. This keeps every
# element of the encoder output, so vectors may be very large.
# NOTE(review): check the shapes printed above; pooling would shrink memory use.
tformer_query_embeddings_combined = tformer_query_embeddings_combined.view(len(tformer_query_embeddings_combined), -1)
tformer_gallery_embeddings_combined = tformer_gallery_embeddings_combined.view(len(tformer_gallery_embeddings_combined), -1)

In [None]:
import torch.nn.functional as F

# L2-normalise each embedding row (p=2, dim=1 are F.normalize's defaults) so that
# dot products between rows are cosine similarities.
tformer_query_embeddings_norm = F.normalize(tformer_query_embeddings_combined, p=2.0, dim=1)
tformer_gallery_embeddings_norm = F.normalize(tformer_gallery_embeddings_combined, p=2.0, dim=1)

In [None]:
# Normalisation must not change shapes.
for embeddings in (tformer_query_embeddings_norm, tformer_gallery_embeddings_norm):
    print(embeddings.shape)

In [None]:

# Both operands are L2-normalised row-wise with F.normalize, so their dot
# product already is the cosine similarity, shape (num_query, num_gallery).
# Dividing by the row norms again is redundant. For an all-zero embedding
# (norm 0 after F.normalize) it also produced 0/0 = NaN across that whole row or
# column, corrupting the ranking.
dot_product_matrix = torch.matmul(tformer_query_embeddings_norm, tformer_gallery_embeddings_norm.T)
sam_cosine_similarities = dot_product_matrix



In [None]:
# Expected (num_query, num_gallery).
print(sam_cosine_similarities.size())

In [None]:
# SigLIP query-gallery cosine-similarity matrix, presumably produced by another
# notebook. NOTE(review): torch.load unpickles -- only load trusted files. Its
# row/column order must match name_query / name_gallery in this notebook.
clip_cosine_similarities = torch.load(f="cosine_similarity_SigLIP.pt")

In [None]:
import torch.nn.functional as F

# Ensemble weights for the three similarity sources (they sum to 1.0).
cnn_weight = 0.1
sam_weight = 0.5
clip_weight = 0.4

# CNN cosine similarity, computed the same way as the SAM one: dot products of
# L2-normalised embeddings. The weighted sum previously referenced the undefined
# names cnn_cosine_similarities / tformer_weight / tformer_cosine_similarities
# and raised NameError.
cnn_cosine_similarities = torch.matmul(
    F.normalize(cnn_query_embeddings_combined),
    F.normalize(cnn_gallery_embeddings_combined).T,
)

# All three matrices must be (num_query, num_gallery). The SigLIP one is loaded
# from disk, so at least check the shapes.
# NOTE(review): its row/column ORDER must also match name_query / name_gallery;
# that cannot be checked here.
assert (
    cnn_cosine_similarities.shape
    == sam_cosine_similarities.shape
    == clip_cosine_similarities.shape
), (cnn_cosine_similarities.shape, sam_cosine_similarities.shape, clip_cosine_similarities.shape)

weighted_similarities = (
    sam_weight * sam_cosine_similarities
    + cnn_weight * cnn_cosine_similarities
    + clip_weight * clip_cosine_similarities
)


In [None]:
# Rank gallery columns for each query row, most similar first.
sorted_similarities, sorted_indices = weighted_similarities.sort(dim=1, descending=True)

In [None]:
# One line per query: "Q<n>: " followed by every gallery index, most similar
# first, separated by single spaces.
# NOTE(review): these are column indices into name_gallery, not gallery image
# ids. Map them through gallery_imgs_no if the expected format uses file names
# -- confirm.
# `with` closes the file even if a write fails. Converting the whole index
# matrix once with .tolist() avoids a per-element .item() call.
with open("ensemble_rank_list.txt", "w") as f:
    for i, ranking in enumerate(sorted_indices.tolist()):
        f.write("Q" + str(i + 1) + ": ")
        f.write(" ".join(str(x) for x in ranking))
        f.write("\n")

In [48]:
import os
import shutil

# Number of top-ranked gallery images copied for each query.
TOP_K = 10

# Copy each query's TOP_K most similar gallery images into
# <download_path>/image_<query id>/ for visual inspection.
# Row q of sorted_indices is taken to belong to query_imgs_no[q] (assumes
# query_imgs_no is index-aligned with name_query). Its entries are column
# indices into name_gallery, so the source file comes from name_gallery -- the
# list the gallery embeddings were computed from -- rather than being rebuilt
# from a separately globbed id list. enumerate() replaces the O(n) list.index()
# lookup per query.
for query_pos, query_img_no in enumerate(query_imgs_no):
    query_img_folder = os.path.join(download_path, f"image_{query_img_no}")
    os.makedirs(query_img_folder, exist_ok=True)

    for gallery_pos in sorted_indices[query_pos, :TOP_K].tolist():
        src_path = name_gallery[gallery_pos]
        dst_path = os.path.join(query_img_folder, os.path.basename(src_path))
        shutil.copy(src_path, dst_path)


ValueError: I/O operation on closed file.

In [55]:
# Display the query image ids (file stems of the query .jpg files), in list order.
list(query_imgs_no)

['1258',
 '1656',
 '1709',
 '2032',
 '2040',
 '2176',
 '2461',
 '27',
 '2714',
 '316',
 '35',
 '3502',
 '3557',
 '3833',
 '3906',
 '4354',
 '4445',
 '4716',
 '4929',
 '776']

In [58]:
# Top-10 gallery indices per query, written as Python list reprs:
# "Q<n>: [i1, i2, ...]".
with open("ensemble_top10.txt", 'w') as f:
    for row, _ in enumerate(query_imgs_no):
        top_10 = sorted_indices[row, :10].tolist()
        f.write(f"Q{row + 1}: {top_10}\n")