In [1]:
import os
import torch.nn.functional as F
from transformers import AutoProcessor, CLIPModel
import torch
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Single source of truth for the checkpoint so the processor and the model
# cannot drift apart when one of them is changed.
CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"

processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
# Inference only: switch to eval mode and place the weights on the current
# CUDA device. The device is spelled "cuda" (rather than the bare index 0) so
# that it is written the same way everywhere in the notebook.
model = model.eval().to("cuda")

In [3]:
def extract_img_embed(image: Image.Image) -> torch.Tensor:  # 1 x 768
    """Embed a single image with the module-level CLIP ``model``.

    Args:
        image: A PIL image. It is passed unchanged to the module-level
            ``processor``, which produces the tensors fed to the model.

    Returns:
        The result of ``model.get_image_features`` for ``image``. The original
        author records the shape as 1 x 768 for this checkpoint.

    Note:
        Autograd is not disabled here; callers that only need inference
        should wrap the call in ``torch.no_grad()``.
    """
    # Follow the model to whichever device it currently lives on instead of
    # assuming "cuda", so inputs and weights can never end up on different devices.
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    return model.get_image_features(**inputs)

In [4]:
# Root folder that holds one sub-folder per support-set class id.
supportset_dir = '/home/steve/Datasets/OpenEarthMap-FSS/temptest'
# Source images sit in the "images" folder directly under that root.
source_dir = f'{supportset_dir}/images'

In [5]:
# Embed every file in `source_dir`. Row k of `s_image_features` corresponds to
# `s_image_list[k]`; files are visited in sorted order so the mapping is
# reproducible across runs.
s_image_list = []
feature_chunks = []  # per-image embeddings, concatenated once after the loop
# Gradients are never needed for these embeddings, so autograd is switched
# off for the whole pass.
with torch.no_grad():
    for file in sorted(os.listdir(source_dir)):
        # The context manager closes the underlying file handle as soon as the
        # embedding has been computed, instead of leaving one open per image.
        with Image.open(os.path.join(source_dir, file)) as img:
            feat = extract_img_embed(img)
        s_image_list.append(file)
        feature_chunks.append(feat)
s_image_features = torch.cat(feature_chunks, dim=0)
print(s_image_features.shape)

torch.Size([80, 768])


In [6]:
# For each support-set class id (8..11), embed every image found under
# `<supportset_dir>/<id>/images` and store both the stacked features and the
# matching file names:
#   sset_image_features[id]['feat'] -> concatenated embeddings
#   sset_image_features[id]['list'] -> file names, same order as the rows of 'feat'
sset_image_features = {}
for i in range(8, 12):
    image_dir = os.path.join(supportset_dir, str(i), 'images')
    t_image_list = []
    feature_chunks = []
    # Inference only, so autograd is disabled for the whole class.
    with torch.no_grad():
        # Sorted so the row order is reproducible (os.listdir order is arbitrary).
        for file in sorted(os.listdir(image_dir)):
            # Close each file handle once its embedding has been computed.
            with Image.open(os.path.join(image_dir, file)) as img:
                feat = extract_img_embed(img)
            t_image_list.append(file)
            feature_chunks.append(feat)

    t_image_features = torch.cat(feature_chunks, dim=0)
    sset_image_features[i] = {
        'feat': t_image_features,
        'list': t_image_list,
    }

Calculate cosine similarity

In [22]:
# Per-class cut-offs on the mean cosine similarity, stored as [low, high]:
#   mean similarity >= high -> the source image goes to 'force_overlay'
#   mean similarity <= low  -> the source image goes to 'filter_list'
#   anything in between     -> the image is left out of both lists
threshold = {
    8: [0.65, 0.7],
    9: [0.72, 0.82],
    10: [0.7, 0.73],
    11: [0.67, 0.8]
}

# Drive the loop from the threshold table so the set of class ids is defined
# in exactly one place.
out = {i: {'filter_list': [], 'force_overlay': []} for i in threshold}
for i, (low, high) in threshold.items():
    sset_feat = sset_image_features[i]['feat']
    # Broadcasting the two unsqueezed operands yields one similarity per
    # (support image, source image) pair: rows are support images, columns
    # are source images.
    sim_matrix = F.cosine_similarity(sset_feat.unsqueeze(1), s_image_features.unsqueeze(0), dim=2)
    # Average over the support images once for all source images, and move the
    # result to Python floats in a single transfer rather than one per image.
    mean_sim = sim_matrix.mean(dim=0).tolist()
    for name, score in zip(s_image_list, mean_sim):
        if score >= high:
            out[i]['force_overlay'].append(name)
        elif score <= low:
            out[i]['filter_list'].append(name)

In [23]:
import json

# Persist the per-class selections. The integer class ids used as keys of
# `out` are written as strings, because JSON object keys are always strings.
with open('filter.json', 'w') as out_file:
    json.dump(out, fp=out_file)