In [1]:
import numpy as np
import glob
# import cv2
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
download_path = "datasets"  # change to your own download path
path_query = download_path + "/query_img_4186"
path_query_txt = download_path + "/query_img_box_4186"

# path_query_txt is the directory to the bounding box information of the instance(s) for the query images
path_gallery = download_path + "/gallery_4186"

name_query = glob.glob(path_query + "/*.jpg")
num_query = len(name_query)

name_box = glob.glob(path_query_txt+"/*.txt")

name_gallery = glob.glob(path_gallery + "/*.jpg")
num_gallery = len(name_gallery)
record_all = np.zeros((num_query, len(name_gallery)))

query_imgs_no = [x.split("/")[-1][:-4] for x in glob.glob(path_query + "/*.jpg")]
gallery_imgs_no = [x.split("/")[-1][:-4] for x in glob.glob(path_gallery + "/*.jpg")]

In [3]:
# from transformers import AutoImageProcessor, AutoModel

# processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
# model = AutoModel.from_pretrained("facebook/dinov2-base")

In [4]:
import timm
import torch


model = timm.create_model(
    'efficientnet_b0',
    pretrained=True,
    num_classes=0, 
)
model = model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from torchvision.transforms import Resize, Compose, ToTensor
from torch.utils.data import Dataset, DataLoader


class QueryDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, bounding_box_path,transform):        
        self.image_paths = image_paths
        self.transform = transform
        self.bounding_box_path = bounding_box_path

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path)

        bounding_box = self.bounding_box_path[idx]
        x, y, w, h = np.loadtxt(bounding_box)

        # image = image.crop((x, y, x+w, y+h))        
        image = self.transform(image)
        return image

class GalleryDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, transform):        
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path)        
        image = self.transform(image)
        return image  



In [6]:
from torchvision.transforms import Resize, Compose, Normalize, CenterCrop
from torchvision.transforms.functional import InterpolationMode
import torch

data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

In [7]:
dataset = QueryDataset(name_query, name_box, transform=transforms)
data_loader = DataLoader(dataset, batch_size=4, num_workers=4)  

dino_query_embeddings = []
with torch.no_grad():
    for images in data_loader:
        outputs = model.forward_features(images.to(device))
        # outputs = model.forward_features(images)
        dino_query_embeddings.append(outputs)
        # torch.cuda.empty_cache()
        

In [8]:
print(dino_query_embeddings[0].shape)

torch.Size([4, 1280, 7, 7])


In [9]:
# # del model
# # gc.collect()
# # torch.cuda.empty_cache()

# model = timm.create_model(
#     'samvit_base_patch16.sa1b',
#     pretrained=True,
#     num_classes=0, 
# )
# model = model.eval()
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)

In [9]:
dataset = GalleryDataset(name_gallery, transform=transforms)
data_loader = DataLoader(dataset, batch_size=4, num_workers=4)  

dino_gallery_embeddings = []
with torch.no_grad():
    for images in data_loader:
        outputs = model.forward_features(images.to(device))
        dino_gallery_embeddings.append(outputs)
        # torch.cuda.empty_cache()

In [10]:
print(dino_gallery_embeddings[0].shape)
print(dino_query_embeddings[0].shape)

torch.Size([4, 1280, 7, 7])
torch.Size([4, 1280, 7, 7])


In [11]:

# dino_query_embeddings = []

# for i, query_img_no in enumerate(query_imgs_no[:1]):    
#     per_query_name = path_query + "/" + str(query_img_no) + ".jpg"
#     per_query_txt_name = path_query_txt + "/" + str(query_img_no) + ".txt"
#     print(per_query_name)
#     x, y, w, h = np.loadtxt(per_query_txt_name)
#     per_query = cv2.imread(per_query_name)    
#     per_query = cv2.cvtColor(per_query, cv2.COLOR_BGR2RGB)
#     per_query = per_query[int(y):int(y+h), int(x):int(x+w)]
#     inputs = processor(images=per_query, return_tensors="pt")
#     outputs = model(**inputs)
#     dino_query_embeddings.append(outputs.last_hidden_state)

In [12]:
# dino_gallery_embeddings = []

# for j, gallery_img_no in enumerate(gallery_imgs_no):
#     per_gallery_name = path_gallery + "/" + str(gallery_img_no) + ".jpg"
#     per_gallery = cv2.imread(per_gallery_name)
#     per_gallery = cv2.cvtColor(per_gallery, cv2.COLOR_BGR2RGB)
#     inputs = processor(images=per_gallery, return_tensors="pt")
#     outputs = model(**inputs)
#     dino_gallery_embeddings.append(outputs.last_hidden_state)


In [14]:
import torch.nn.functional as F

dino_query_embeddings = torch.stack([qe.flatten() for qe in dino_query_embeddings])
dino_gallery_embeddings = torch.stack([ge.flatten() for ge in dino_gallery_embeddings])


print(dino_query_embeddings[0].shape)
print(dino_gallery_embeddings[0].shape)

torch.Size([250880])
torch.Size([250880])


In [15]:
dino_query_embeddings_norm = dino_query_embeddings / dino_query_embeddings.norm(
    dim=1, keepdim=True
)
dino_gallery_embeddings_norm = dino_gallery_embeddings / dino_gallery_embeddings.norm(
    dim=1, keepdim=True
)

In [16]:
print(dino_query_embeddings_norm[0].shape)  
print(dino_gallery_embeddings_norm[0].shape)

torch.Size([250880])
torch.Size([250880])


In [17]:
batch_size = 4  # Adjust this based on your available memory
num_batches = (dino_query_embeddings_norm.shape[0] + batch_size - 1) // batch_size

dino_cosine_similarities = []

for i in range(num_batches):
    start = i * batch_size
    end = start + batch_size
    batch_query = dino_query_embeddings_norm[start:end].unsqueeze(1)
    batch_similarities = F.cosine_similarity(
        batch_query, dino_gallery_embeddings_norm.unsqueeze(0), dim=2
    )
    dino_cosine_similarities.append(batch_similarities)

dino_cosine_similarities = torch.cat(dino_cosine_similarities)

: 

In [14]:
# Compute cosine similarity
# dino_cosine_similarities = F.cosine_similarity(
#     dino_query_embeddings_norm.unsqueeze(1), dino_gallery_embeddings_norm.unsqueeze(0), dim=2
# )

# Sort and select top similarities for each query
dino_cosine = []
for i in range(len(dino_query_embeddings)):
    dino_cosine.append(torch.argsort(dino_cosine_similarities[i, :], descending=True))

torch.Size([250880])


: 

In [None]:
f = open("rank_list.txt", "w")
for i in range(len(dino_cosine)):
    f.write("Q" + str(i + 1) + ": ")    
    f.write(" ".join([str(x.item()) for x in dino_cosine[i]]))
    f.write("\n")
f.close()