In [1]:
import time
import json
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.nn import functional as F
from PIL import Image
from mmcv.runner import load_checkpoint
from openselfsup.models import build_model

In [2]:
class Trans(object):
    def __init__(self, cfg):
        global_trans_list = [
            T.Resize(256),
            T.CenterCrop(224)
        ]
        self.global_transform = T.Compose(global_trans_list)
        self.img_transform = T.Compose([
            T.ToTensor(),
            T.Normalize(**cfg.img_norm_cfg)])

In [3]:
# forward global image for knn retrieval
def global_forward(img, model):
    x = torch.stack(img).cuda()
    with torch.no_grad():
        x = model.backbone(x)
        feats = model.neck(x)[0]
        feats_norm = F.normalize(feats, dim=1)
    return feats_norm.detach()

In [4]:
## use this part for coco
work_dirs = '../work_dirs/selfsup/'
global_config = '../configs/selfsup/orl/coco/stage1/r50_bs512_ep800_extract_feature.py'
global_checkpoint = work_dirs + 'orl/coco/stage1/r50_bs512_ep800/epoch_800.pth'
feat_bank_npy = work_dirs + 'orl/coco/stage1/r50_bs512_ep800_extract_feature/feature_epoch_800.npy'

In [5]:
# ## use this part for coco+
# work_dirs = '../work_dirs/selfsup/'
# global_config = '../configs/selfsup/orl/cocoplus/stage1/r50_bs512_ep800_extract_feature.py'
# global_checkpoint = work_dirs + 'orl/cocoplus/stage1/r50_bs512_ep800/epoch_800.pth'
# feat_bank_npy = work_dirs + 'orl/cocoplus/stage1/r50_bs512_ep800_extract_feature/feature_epoch_800.npy'

In [6]:
global_cfg = mmcv.Config.fromfile(global_config)

# build the model and load checkpoint
global_model = build_model(global_cfg.model)
load_checkpoint(global_model, global_checkpoint, map_location='cpu')
global_model = global_model.cuda()
global_model.eval()

BYOL(
  (online_net): Sequential(
    (0): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): SyncBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): SyncBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): SyncBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inp

In [7]:
## use this part for coco (train2017)
# load data
train_json = '../data/coco/annotations/instances_train2017.json'
train_root = '../data/coco/train2017/'
with open(train_json, 'r') as json_file:
    data = json.load(json_file)
train_fns = [train_root + item['file_name'] for item in data['images']]
imgids = [item['id'] for item in data['images']]
knn_imgids = []
# batch processing
trans = Trans(global_cfg)
batch = 512
keys = 10
feats_bank = torch.from_numpy(np.load(feat_bank_npy)).cuda()
for i in range(0, len(train_fns), batch):
    print("[INFO] processing batch: {}".format(i + 1))
    start = time.time()
    if (i + batch) < len(train_fns):
        images = [Image.open(fn).convert('RGB') for fn in train_fns[i:i + batch]]
    else:
        images = [Image.open(fn).convert('RGB') for fn in train_fns[i:len(train_fns)]]
    global_images = [trans.global_transform(img) for img in images]
    global_tensors = [trans.img_transform(img) for img in global_images]
    # retrieve knn images
    query_feats = global_forward(global_tensors, global_model)
    similarity = torch.mm(query_feats, feats_bank.permute(1, 0))
    I = torch.topk(similarity, keys + 1, dim=1)[1].cpu()
    I = I[:,1:]  # exclude itself (i.e., 1st nn)
    knn_list = I.numpy().tolist()
    [knn_imgids.append(knn) for knn in knn_list]
    end = time.time()
    print("[INFO] batch {} took {:.4f} seconds".format(i + 1, end - start))

[INFO] processing batch: 1
[INFO] batch 1 took 10.0100 seconds
[INFO] processing batch: 513
[INFO] batch 513 took 10.5673 seconds
[INFO] processing batch: 1025
[INFO] batch 1025 took 10.6541 seconds
[INFO] processing batch: 1537
[INFO] batch 1537 took 10.0347 seconds
[INFO] processing batch: 2049
[INFO] batch 2049 took 10.2120 seconds
[INFO] processing batch: 2561
[INFO] batch 2561 took 9.6567 seconds
[INFO] processing batch: 3073
[INFO] batch 3073 took 9.6550 seconds
[INFO] processing batch: 3585
[INFO] batch 3585 took 10.1457 seconds
[INFO] processing batch: 4097
[INFO] batch 4097 took 9.9481 seconds
[INFO] processing batch: 4609
[INFO] batch 4609 took 10.0417 seconds
[INFO] processing batch: 5121
[INFO] batch 5121 took 9.7011 seconds
[INFO] processing batch: 5633
[INFO] batch 5633 took 9.9319 seconds
[INFO] processing batch: 6145
[INFO] batch 6145 took 9.7829 seconds
[INFO] processing batch: 6657
[INFO] batch 6657 took 9.7906 seconds
[INFO] processing batch: 7169
[INFO] batch 7169 t

In [8]:
# ## use this part for coco+ (train2017+unlabeled2017)
# # load data
# train_json = '../data/coco/annotations/instances_train2017.json'
# unlabeled_json = '../data/coco/annotations/image_info_unlabeled2017.json'
# train_root = '../data/cocoplus/trainplus2017/'
# with open(train_json, 'r') as f1:
#     data = json.load(f1)
# with open(unlabeled_json, 'r') as f2:
#     unlabeled_data = json.load(f2)
# data['images'].extend(unlabeled_data['images'])
# train_fns = [train_root + item['file_name'] for item in data['images']]
# imgids = [item['id'] for item in data['images']]
# knn_imgids = []
# # batch processing
# trans = Trans(global_cfg)
# batch = 512
# keys = 10
# feats_bank = torch.from_numpy(np.load(feat_bank_npy)).cuda()
# for i in range(0, len(train_fns), batch):
#     print("[INFO] processing batch: {}".format(i + 1))
#     start = time.time()
#     if (i + batch) < len(train_fns):
#         images = [Image.open(fn).convert('RGB') for fn in train_fns[i:i + batch]]
#     else:
#         images = [Image.open(fn).convert('RGB') for fn in train_fns[i:len(train_fns)]]
#     global_images = [trans.global_transform(img) for img in images]
#     global_tensors = [trans.img_transform(img) for img in global_images]
#     # retrieve knn images
#     query_feats = global_forward(global_tensors, global_model)
#     similarity = torch.mm(query_feats, feats_bank.permute(1, 0))
#     I = torch.topk(similarity, keys + 1, dim=1)[1].cpu()
#     I = I[:,1:]  # exclude itself (i.e., 1st nn)
#     knn_list = I.numpy().tolist()
#     [knn_imgids.append(knn) for knn in knn_list]
#     end = time.time()
#     print("[INFO] batch {} took {:.4f} seconds".format(i + 1, end - start))

In [10]:
# save image id and knn image id as a json file
data_new = {}
num_image = 118287  # 118287 for coco, 241690 for coco+ 
save_path = '../data/coco/meta/train2017_{}nn_instance.json'.format(keys)
# save_path = '../data/cocoplus/meta/trainplus2017_{}nn_instance.json'.format(keys)
assert len(imgids) == len(knn_imgids) == len(train_fns) == num_image, \
    "Mismatch the total number of images in training set, got: {}".format(len(knn_imgids))
# dict
info = {}
image_info = {}
pseudo_anno = {}
info['knn_image_num'] = keys
image_info['file_name'] = [item['file_name'] for item in data['images']]
image_info['id'] = [item['id'] for item in data['images']]
pseudo_anno['image_id'] = imgids
pseudo_anno['knn_image_id'] = knn_imgids
data_new['info'] = info
data_new['images'] = image_info
data_new['pseudo_annotations'] = pseudo_anno
with open(save_path, 'w') as f:
    json.dump(data_new, f)
print("[INFO] image-level knn json file has been saved to {}".format(save_path))

[INFO] image-level knn json file has been saved to ../data/coco/meta/train2017_10nn_instance.json
