In [None]:
import os
import torch.nn.functional as F
from transformers import AutoProcessor, CLIPModel
from huggingface_hub import hf_hub_download
import torch, open_clip
from PIL import Image

In [None]:
# CLIP ViT-L/14 checkpoint; the processor and the model must come from the same one.
CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT)
# Assigning the result also suppresses the (very long) model repr in the cell output.
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).eval().to(DEVICE)

In [None]:
@torch.no_grad()  # inference only: never build an autograd graph, even if the caller forgets
def extract_img_embed(image: Image.Image) -> torch.Tensor:  # 1 x 768
    """Embed a single PIL image with the module-level CLIP ``processor`` and ``model``.

    Args:
        image: PIL image. Resizing and normalisation are delegated to ``processor``.

    Returns:
        The output of ``model.get_image_features`` for a batch of one image
        (1 x 768 for clip-vit-large-patch14), left on the model's device.
    """
    # Move inputs to wherever the model actually lives instead of assuming "cuda".
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    return model.get_image_features(**inputs)

In [None]:
# Dataset root. Set OPENEARTHMAP_FSS_ROOT to point elsewhere instead of editing this cell.
dataset_root = os.environ.get('OPENEARTHMAP_FSS_ROOT', '/home/steve/Datasets/OpenEarthMap-FSS')
# Query images: each one is ranked against every image in image_dir.
source_dir = os.path.join(dataset_root, 'testset', 'images')
# Retrieval pool searched for every query image.
image_dir = os.path.join(dataset_root, 'trainset', 'images')

In [None]:
s_image_list = []
s_image_features = []
for file in os.listdir(source_dir):
    img = Image.open(os.path.join(source_dir, file))
    with torch.no_grad():
        feat = extract_img_embed(img)
    s_image_list.append(file)
    s_image_features.append(feat)
s_image_features = torch.cat(s_image_features, dim=0)

In [None]:
t_image_list = []
t_image_features = []
for file in os.listdir(image_dir):
    img = Image.open(os.path.join(image_dir, file))
    with torch.no_grad():
        feat = extract_img_embed(img)
    t_image_list.append(file)
    t_image_features.append(feat)

t_image_features = torch.cat(t_image_features, dim=0)

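## Calculate cosine similarity

Compute the pairwise cosine similarity between every test-set (query) embedding and every train-set embedding: row *i* of `sim_matrix` scores all train images against test image *i*.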

In [None]:
# Unit-normalise each embedding so a plain matrix product gives cosine similarity.
# Unlike broadcasting F.cosine_similarity over (S, 1, D) x (1, T, D), this never
# materialises an S x T x D intermediate, which can exhaust GPU memory.
sim_matrix = F.normalize(s_image_features, dim=1) @ F.normalize(t_image_features, dim=1).T
assert sim_matrix.shape == (len(s_image_list), len(t_image_list))
sim_matrix.shape

In [None]:
# Per-query folders of the top-ranked train images are written here. Deriving the
# path from image_dir keeps it alongside the dataset wherever that dataset lives
# (for the default paths this is .../OpenEarthMap-FSS/trainset/similarity-vit).
out_dir = os.path.join(os.path.dirname(image_dir), 'similarity-vit')
os.makedirs(out_dir, exist_ok=True)

In [None]:
# For each query (test) image, write its top-k most similar train images into
# out_dir/<query stem>/ as "<rank>_sim_<score>_<file>" for visual inspection,
# plus the query itself as "0_<file>" so it sorts first.
top_k = min(10, sim_matrix.shape[1])  # avoid IndexError when fewer than 10 train images exist
# topk returns each row's best matches already sorted in descending order.
topk_result = torch.topk(sim_matrix, k=top_k, dim=1)
top_scores = topk_result.values.tolist()
top_indices = topk_result.indices.tolist()

for s_i, s_name in enumerate(s_image_list):
    # splitext keeps multi-dot names distinct ("a.b.tif" -> "a.b", not "a").
    query_dir = os.path.join(out_dir, os.path.splitext(s_name)[0])
    os.makedirs(query_dir, exist_ok=True)
    for rank, (score, t_idx) in enumerate(zip(top_scores[s_i], top_indices[s_i]), start=1):
        t_name = t_image_list[t_idx]
        with Image.open(os.path.join(image_dir, t_name)) as img:
            img.save(os.path.join(query_dir, f'{rank}_sim_{score:.3f}_{t_name}'))
    # save the original (query) image too
    with Image.open(os.path.join(source_dir, s_name)) as img:
        img.save(os.path.join(query_dir, f'0_{s_name}'))

In [None]:
# Map each query (test) filename to its top-k most similar train filenames, most
# similar first. Key order follows s_image_list.
top_k = min(10, sim_matrix.shape[1])  # avoid IndexError when fewer than 10 train images exist
top_indices = torch.topk(sim_matrix, k=top_k, dim=1).indices.tolist()
out_json_dict = {
    s_name: [t_image_list[t_idx] for t_idx in row]
    for s_name, row in zip(s_image_list, top_indices)
}

In [None]:
import json

# Persist the {query filename: [ranked train filenames]} mapping built above.
# NOTE(review): the output name says "remoteclip", but the embeddings in this notebook
# come from openai/clip-vit-large-patch14. Confirm this is intended before reusing it.
with open('mapping_remoteclip.json', mode='w') as json_file:
    json.dump(obj=out_json_dict, fp=json_file)