In [1]:
import os
import sys
import torch

from PIL import Image
from matplotlib import pyplot as plt

In [5]:
# set up args
import argparse

# The notebook parses an empty argv, so the defaults below are what is used.
# Every numeric option carries an explicit `type=` so that the same parser
# also works when values are overridden from a real command line; without it
# overrides arrive as strings and e.g. "{:0>2d}".format(args.load_epoch - 1)
# fails. (argparse leaves non-string defaults untouched, so the printed
# Namespace is unchanged.)
parser = argparse.ArgumentParser()
# for dataset
parser.add_argument('--data', default='./dataset/dataset_v2')
parser.add_argument('--train_list', default='./dataset/dataset_v2/sample.txt')
parser.add_argument('--query_list', default='./dataset/dataset_v2/query.txt')
# for model
parser.add_argument('--num_class', default=181, type=int)
parser.add_argument('--softmax', default=0, type=int)
parser.add_argument('--s', default=30.0, type=float)
parser.add_argument('--m', default=0.15, type=float)
# for training
parser.add_argument('--max_epochs', default=10, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=3e-4, type=float)
parser.add_argument('--weight_decay', default=1e-4, type=float)
# for visualization
parser.add_argument('--log_steps', default=1, type=int)
parser.add_argument('--ckpt_dir', default="output/retrieval/vit_lr0.0001_0.0001_bs128_epoch30_arcface_m0.15_s30_translate_cn/")
parser.add_argument('--load_epoch', default=30, type=int)

args = parser.parse_args(args=[])
print(args)

Namespace(batch_size=128, ckpt_dir='output/retrieval/vit_lr0.0001_0.0001_bs128_epoch30_arcface_m0.15_s30_translate_cn/', data='./dataset/dataset_v2', load_epoch=30, log_steps=1, lr=0.0003, m=0.15, max_epochs=10, num_class=181, query_list='./dataset/dataset_v2/query.txt', s=30.0, softmax=0, train_list='./dataset/dataset_v2/sample.txt', weight_decay=0.0001)


In [6]:
# set up model
# Build the transformer retrieval model on the GPU, restore the weights saved
# for `args.load_epoch`, and switch to inference mode. The checkpoint files
# appear to be numbered from 0, hence the `- 1` (TODO confirm against the
# training script). NOTE: torch.load unpickles the file, so only point
# ckpt_dir at checkpoints you trust.
from supervised.model import Model, DOLG, TransformerModel

model = TransformerModel(num_class=args.num_class, args=args).cuda()
ckpt_name = "model_{:0>2d}.ckpt".format(args.load_epoch - 1)
ckpt_path = os.path.join(args.ckpt_dir, ckpt_name)
model.load_state_dict(torch.load(ckpt_path))
model.eval()

TransformerModel(
  (backbone): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=3072, out_features=768, bias=True)
            (4): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(
          (ln

In [7]:
# set up dataset
from torch.utils.data import DataLoader
from supervised.dataset import Dataset


def _make_eval_loader(list_path):
    """Build the test-mode Dataset for `list_path` and a DataLoader over it.

    The loader visits every image exactly once (drop_last=False) in a fixed
    order (shuffle=False). Feature extraction does not benefit from shuffling,
    and a stable order keeps the gallery/query row order, and therefore top-k
    tie-breaking and the saved results, identical across runs.

    Returns:
        (dataset, loader)
    """
    dataset = Dataset(args.data, list_path, mode="test", visual=True)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=2, drop_last=False)
    return dataset, loader


query_dataset, query_loader = _make_eval_loader(args.query_list)
train_dataset, train_loader = _make_eval_loader(args.train_list)

In [8]:
# for images, labels, im_pth, coord in train_loader:
#     plt.figure(figsize=(8 * 4, 2 * 4))
#     for i in range(8):
#         im = Image.open(im_pth[i])
#         plt.subplot(2, 8, i+1)
#         plt.imshow(im)
#         plt.subplot(2, 8, i+1+8)
#         plt.imshow(images[i].permute(1, 2, 0).numpy())
#     plt.show()
#     plt.close()
#     break

In [9]:
# extract feature
def _extract_features(loader):
    """Embed every image yielded by `loader` and index the results by image path.

    The loader is expected to yield (images, labels, im_pth, coord) batches.

    Returns:
        dict mapping image path -> {"feat": L2-normalised embedding (CPU tensor),
        "label": the sample's label (CPU tensor),
        "coord": the per-sample 4th batch item (presumably a bounding box;
        TODO confirm against supervised.dataset)}.
        If a path occurs more than once, the last occurrence wins.
    """
    feat_dic = {}
    with torch.no_grad():
        for images, labels, im_pth, coord in loader:
            # the model consumes labels too (presumably for its margin head),
            # so they go to the GPU alongside the images
            features, _ = model(images.cuda(), labels.cuda(), args)
            # unit-length embeddings make the later dot product a cosine
            # similarity; F.normalize clamps the norm so a zero vector
            # does not turn into NaNs
            features = torch.nn.functional.normalize(features, dim=-1).cpu()
            # keep labels on the CPU: storing per-sample CUDA scalars forces a
            # device sync on every later comparison/`.item()`
            labels = labels.cpu()
            for i in range(images.shape[0]):
                feat_dic[im_pth[i]] = {
                    "feat": features[i],
                    "label": labels[i],
                    "coord": coord[i],
                }
    return feat_dic


feat_dic = _extract_features(train_loader)
q_feat_dic = _extract_features(query_loader)

In [14]:
import numpy as np

def _stack_entries(dic, skip_label=None):
    """Stack the entries of a path -> {"feat", "label", ...} dict into aligned tensors.

    Rows follow the dict's insertion order. Entries whose "label" equals
    `skip_label` are dropped.

    Returns:
        (features, labels, paths): stacked "feat" tensors, stacked "label"
        tensors moved to CPU, and the image paths in the same row order.

    Raises:
        ValueError: if no entries remain to stack.
    """
    feats, labs, paths = [], [], []
    for path, entry in dic.items():
        if skip_label is not None and entry["label"] == skip_label:
            continue
        feats.append(entry["feat"])
        labs.append(entry["label"])
        paths.append(path)
    if not paths:
        raise ValueError("no entries to stack (dict is empty or every label equals {})".format(skip_label))
    return torch.stack(feats), torch.stack(labs).cpu(), paths


# gallery: every extracted image
features, labels, im_pth_list = _stack_entries(feat_dic)
# queries labelled -1 (presumably "no matching gallery class"; TODO confirm)
# are excluded from evaluation
q_features, q_labels, q_im_pth_list = _stack_entries(q_feat_dic, skip_label=-1)

# rows are L2-normalised embeddings, so this is cosine similarity: (num_query, num_gallery)
sim = torch.matmul(q_features, features.T)
# per query, gallery indices ordered from most to least similar
vec_ranks = torch.argsort(-sim, dim=1)

print(sim.shape)

torch.Size([905, 3950])


In [12]:
# save top 10 label
# For every evaluated query, store its label, its coord and the labels of its
# 10 most similar gallery images, then pickle the mapping to query_dic.pk.
import pickle

# topk() raises if k exceeds the gallery size, so cap it
k_save = min(10, sim.shape[1])
_, b = sim.topk(k_save)
# top10_labels[q, j] = label of the j-th most similar gallery image for query q
top10_labels = labels.cpu()[b]
print(top10_labels.shape)
query_dic = {}
certify_cnt = 0  # running total of distinct labels among each query's top-k
for i, im_pth in enumerate(q_im_pth_list):
    query_dic[im_pth] = {
        "label": q_feat_dic[im_pth]["label"].cpu().numpy(),
        "coord": q_feat_dic[im_pth]["coord"].numpy(),
        # clone(): top10_labels[i] is a view, and pickling a tensor view
        # serialises its whole underlying storage (every query's top-k),
        # which would make the file grow quadratically with the query count
        "top10": top10_labels[i].clone(),
    }
    certify_cnt += torch.unique(top10_labels[i]).shape[0]
# mean number of distinct labels in a query's top-k (1.0 = all neighbours agree)
print("N:", certify_cnt / len(q_im_pth_list))
with open("query_dic.pk", "wb") as f:
    pickle.dump(query_dic, f)

torch.Size([1067, 10])
N: 1.1152764761012184


In [15]:
# precision@k: for each query, the fraction of its k nearest gallery images that
# share its label, averaged over all queries
topk = [1, 5, 10, 20]
gallery_labels = labels.cpu()
query_labels = q_labels.cpu().reshape(-1, 1)  # (num_query, 1), broadcasts over k
for k in topk:
    _, b = sim.topk(k)
    target = gallery_labels[b]  # (num_query, k) labels of the retrieved images
    hits = (target == query_labels).sum(dim=-1)
    acc = (hits.float() / k).mean()
    print("precision@{} : {:.4f}".format(k, acc))

precision@1 : 0.9945
precision@5 : 0.9938
precision@10 : 0.9904
precision@20 : 0.8951


In [16]:
# top-k accuracy (hit rate): fraction of queries with at least one same-label
# gallery image among their k nearest neighbours
topk = [1, 5, 10, 20]
gallery_labels = labels.cpu()
query_labels = q_labels.cpu().reshape(-1, 1)  # (num_query, 1), broadcasts over k
for k in topk:
    _, b = sim.topk(k)
    target = gallery_labels[b]  # (num_query, k) labels of the retrieved images
    acc = (target == query_labels).any(dim=-1).float().mean()
    print("topk acc@{} : {:.4f}".format(k, acc))

topk acc@1 : 0.9945
topk acc@5 : 0.9967
topk acc@10 : 0.9967
topk acc@20 : 0.9989


In [None]:
import random
from PIL import Image

def calculate_overlap(dic1, dic2):
    """Fraction of dic1's box covered by dic2's box, for same-label entries.

    Both arguments are feature-dict entries. "label" is a scalar (int or 0-d
    tensor). "coord" is a 4-sequence (x1, y1, x2, y2) that is assumed to satisfy
    x1 <= x2 and y1 <= y2 (presumably box corners in pixels; TODO confirm).
    The result is intersection / area(dic1), not IoU, so it is asymmetric.

    Returns:
        0.0 if the labels differ, if either label is -1, or if dic1's box has
        no positive area. Otherwise a float in [0, 1].
    """
    label1, label2 = int(dic1["label"]), int(dic2["label"])
    if label1 != label2 or label1 == -1 or label2 == -1:
        return 0.0
    # float() works for ints, numpy scalars and 0-d tensors, so the result is
    # always a plain float rather than sometimes a tensor
    x1, y1, x2, y2 = (float(v) for v in dic1["coord"])
    x3, y3, x4, y4 = (float(v) for v in dic2["coord"])
    area1 = (x2 - x1) * (y2 - y1)
    if area1 <= 0:
        # degenerate reference box: avoid ZeroDivisionError / inf / nan
        return 0.0
    x_overlap = max(0.0, min(x2, x4) - max(x1, x3))
    y_overlap = max(0.0, min(y2, y4) - max(y1, y3))
    return x_overlap * y_overlap / area1
        
# Show NUM_VIS random evaluated queries, each followed by its nearest gallery
# images, titled with label, cosine similarity and box overlap.
NUM_VIS = 10         # number of random queries to visualise
NUM_NEIGHBOURS = 9   # gallery neighbours shown next to each query
vis_rng = random.Random(0)  # fixed seed: the same queries are shown on every run

n_shown = min(NUM_NEIGHBOURS, len(im_pth_list))  # gallery may be tiny
for _ in range(NUM_VIS):
    idx = vis_rng.randrange(len(q_im_pth_list))
    q_pth = q_im_pth_list[idx]
    q_dic = q_feat_dic[q_pth]
    fig = plt.figure(figsize=(40, 4))
    plt.subplot(1, NUM_NEIGHBOURS + 1, 1)
    # imshow converts the PIL image to an array immediately, so the file can be
    # closed right after
    with Image.open(q_pth) as im:
        plt.imshow(im)
    plt.title("Label:{}".format(q_dic["label"].item()))
    for j in range(n_shown):
        g_idx = int(vec_ranks[idx][j])
        g_pth = im_pth_list[g_idx]
        dic = feat_dic[g_pth]
        overlap = calculate_overlap(q_dic, dic)
        plt.subplot(1, NUM_NEIGHBOURS + 1, j + 2)
        with Image.open(g_pth) as im:
            plt.imshow(im)
        plt.title("Label:{}, Sim:{:.4f}, Overlap:{:.2f}".format(
            dic["label"].item(), sim[idx][g_idx].item(), overlap
        ))
    # render this query's row now and free the figure instead of accumulating 10
    plt.show()
    plt.close(fig)