In [1]:
import numpy as np
import glob
from PIL import Image
import matplotlib.pyplot as plt
import torch

In [2]:
download_path = "datasets"  # change to your own download path
path_query = download_path + "/query_img_4186"
path_query_txt = download_path + "/query_img_box_4186"

# path_query_txt is the directory to the bounding box information of the instance(s) for the query images
path_gallery = download_path + "/gallery_4186"

name_query = glob.glob(path_query + "/*.jpg")
num_query = len(name_query)

name_box = glob.glob(path_query_txt+"/*.txt")

name_gallery = glob.glob(path_gallery + "/*.jpg")
num_gallery = len(name_gallery)
record_all = np.zeros((num_query, len(name_gallery)))

query_imgs_no = [x.split("/")[-1][:-4] for x in glob.glob(path_query + "/*.jpg")]
gallery_imgs_no = [x.split("/")[-1][:-4] for x in glob.glob(path_gallery + "/*.jpg")]

In [3]:
from transformers import AutoImageProcessor, AutoModel

# processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
# model = AutoModel.from_pretrained("facebook/dinov2-small")
device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)

# total_params = sum(p.numel() for p in model.parameters())
# print(total_params)

In [4]:
import timm
import torch


model = timm.create_model(
    'vit_so400m_patch14_siglip_384',
    pretrained=True,
    num_classes=0, 
)
model = model.eval()

In [5]:
from torchvision.transforms import Resize, Compose, ToTensor
from torch.utils.data import Dataset, DataLoader


class QueryDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, bounding_box_path,transform):        
        self.image_paths = image_paths
        self.transform = transform
        self.bounding_box_path = bounding_box_path

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path)

        bounding_box = self.bounding_box_path[idx]
        x, y, w, h = np.loadtxt(bounding_box)

        # image = image.crop((x, y, x+w, y+h))        
        image = self.transform(image)
        return image

class GalleryDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, transform):        
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path)        
        image = self.transform(image)
        return image  



In [7]:

data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

In [8]:
train_dataset = QueryDataset(name_query, name_box, transform=transforms)
train_data_loader = DataLoader(train_dataset, batch_size=20, num_workers=32)  

query_embeddings = []
with torch.no_grad():
    for images in train_data_loader:
        images = images.to(device)
        outputs = model.forward_features(images)
        query_embeddings.append(outputs)
        torch.cuda.empty_cache()
        

In [9]:
print(query_embeddings[0].shape)

torch.Size([20, 729, 1152])


In [10]:
from tqdm import tqdm
from tqdm.notebook import tqdm

In [13]:
import gc
test_dataset = GalleryDataset(name_gallery, transform=transforms)
test_data_loader = DataLoader(test_dataset, batch_size=2048, num_workers=32)  

gallery_embeddings = []
with torch.no_grad():
    for images in tqdm(test_data_loader, desc="Processing Images"):
        # images = images.to(device)
        outputs = model.forward_features(images)
        gallery_embeddings.append(outputs)
                
        del images, outputs
        gc.collect()


Processing Images:   0%|          | 0/3 [00:00<?, ?it/s]

: 

In [None]:
torch.save(gallery_embeddings, "gallery_embeddings_SigLIP.pt")

In [None]:
query_embeddings_combined = torch.cat(query_embeddings, dim=0)

gallery_embeddings_combined = torch.cat(gallery_embeddings, dim=0)

In [None]:
print(query_embeddings_combined.shape)
print(gallery_embeddings_combined.shape)

torch.Size([20, 257, 384])
torch.Size([4967, 257, 384])


In [None]:
import torch.nn.functional as F

query_embeddings_combined = query_embeddings_combined.view(query_embeddings_combined.shape[0], -1)
gallery_embeddings_combined = gallery_embeddings_combined.view(gallery_embeddings_combined.shape[0], -1)

print(query_embeddings_combined.shape)
print(gallery_embeddings_combined.shape)

torch.Size([20, 98688])
torch.Size([4967, 98688])


In [None]:
query_embeddings_norm = F.normalize(query_embeddings_combined)
gallery_embeddings_norm = F.normalize(gallery_embeddings_combined)

In [None]:
print(query_embeddings_norm)

tensor([[ 0.0012,  0.0025,  0.0022,  ...,  0.0008, -0.0023, -0.0037],
        [-0.0017,  0.0035,  0.0016,  ..., -0.0002, -0.0020, -0.0051],
        [ 0.0004,  0.0034,  0.0014,  ...,  0.0004, -0.0019, -0.0036],
        ...,
        [ 0.0015,  0.0017,  0.0025,  ...,  0.0017, -0.0019, -0.0059],
        [ 0.0010,  0.0021,  0.0027,  ...,  0.0006, -0.0026, -0.0049],
        [ 0.0017,  0.0023,  0.0020,  ...,  0.0005, -0.0010, -0.0054]])


In [None]:
cosine_similarities = torch.empty((query_embeddings_norm.shape[0], gallery_embeddings_norm.shape[0]))

for i in range(query_embeddings_norm.shape[0]):
    query_embedding = query_embeddings_norm[i].unsqueeze(0) # Shape [1, embedding_dim]    
    dot_product = torch.matmul(query_embedding, gallery_embeddings_norm.T)
    query_norm = torch.norm(query_embedding, dim=1)
    gallery_norm = torch.norm(gallery_embeddings_norm, dim=1)
    cosine_similarities[i] = dot_product / (query_norm * gallery_norm)



In [None]:
torch.save(cosine_similarities, "cosine_similarity_SigLIP.pt")

In [None]:
sorted_similarities, sorted_indices = torch.sort(cosine_similarities, dim=1, descending=True)

In [None]:
print(sorted_indices.shape)

for i in range(len(sorted_indices)):
    print(sorted_indices[i])

torch.Size([20, 4967])
tensor([1540, 4329, 3451,  ..., 4664, 1282, 2311])
tensor([1105, 1789,  252,  ..., 4664, 4532, 2311])
tensor([2039, 2848,  535,  ..., 1803, 4664, 2311])
tensor([4715, 4085, 1498,  ..., 4664, 2311, 4532])
tensor([1829, 2258, 2489,  ..., 1679, 1282, 2311])
tensor([3317,  428, 4174,  ..., 1282, 4664, 2311])
tensor([  52, 3724, 4536,  ..., 1679, 4664, 2311])
tensor([2700, 4488, 1281,  ..., 4532, 4664, 2311])
tensor([3581, 4813, 3023,  ..., 1185, 1282, 2311])
tensor([2324, 1814, 1170,  ..., 4664, 2311, 4532])
tensor([4813, 2319, 3519,  ..., 4664, 1282, 2311])
tensor([2562, 3597, 4371,  ..., 4532, 4664, 2311])
tensor([3657, 4465,  397,  ..., 2390, 2311, 4664])
tensor([  45, 2239, 4382,  ..., 4664, 1282, 2311])
tensor([3914, 4298, 2929,  ..., 1185, 1282, 2311])
tensor([1099, 3857, 1504,  ..., 1282, 4664, 2311])
tensor([ 305, 2020, 4799,  ..., 4664, 4532, 2311])
tensor([3613, 1788,  949,  ..., 4664, 1282, 2311])
tensor([2483, 2755, 2464,  ..., 1282, 4664, 2311])
tensor([

In [None]:
f = open("transformer_rank_list.txt", "w")
for i in range(len(sorted_indices)):
    f.write("Q" + str(i + 1) + ": ")    
    f.write(" ".join([str(x.item()) for x in sorted_indices[i]]))
    f.write("\n")
f.close()

In [None]:
import os
import glob
import numpy as np
import shutil


download_path = "datasets"
path_gallery = os.path.join(download_path, "gallery_4186")
all_indices = []

for query_img_no in query_imgs_no:
    # Create a folder for the query image
    query_img_folder = os.path.join(download_path, f"image_{query_img_no}")
    os.makedirs(query_img_folder, exist_ok=True)
    
    # Get the top 10 indices for this query image
    top_10_indices = sorted_indices[query_imgs_no.index(query_img_no), :10].tolist()
    
    # Copy the top 10 most similar gallery images to the query image's folder
    for index in top_10_indices:
        gallery_img_name = gallery_imgs_no[index] + ".jpg" # Assuming gallery images are in .jpg format
        src_path = os.path.join(path_gallery, gallery_img_name)
        dst_path = os.path.join(query_img_folder, gallery_img_name)
        shutil.copy(src_path, dst_path)

    all_indices.extend(top_10_indices)

# Write all the indices to a single text file
with open(os.path.join(download_path, "transformer_top10.txt"), 'w') as file:
    for i, index in enumerate(all_indices):
        file.write(f"Q{i+1}: {index}\n")
