In [1]:
# Root directory of the ILSVRC2012 images; passed as `root` to every LT_Dataset created below.
data_root = '../datasets/ILSVRC2012/'
# Directory holding the checkpoint file (`simclr_TS.pt`) that is loaded into the model further down.
ckp_root = '../checkpoints/imagenet100/'
# Repository root: appended to sys.path for the project imports and used as the prefix of the split .txt files.
code_root = '../'

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors
from collections import Counter, defaultdict
from numpy import linalg as LA
from torchvision import transforms
import torch
import torchvision
from tqdm import tqdm
from sklearn.svm import LinearSVC
from torch.utils.data.sampler import SubsetRandomSampler
import sys
sys.path.append(code_root)
from models.resnet import resnet50
from utils import cvt_state_dict
from data.LT_Dataset import LT_Dataset

In [3]:
# Augmenting pipeline (random crop + flip). Defined here but not passed to any dataset in this cell.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# Deterministic pipeline: resize to 256, centre-crop to 224, convert to tensor.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Split files listing the samples of each subset.
txt_test = code_root + "split/imagenet-100/ImageNet_100_test.txt"
txt_train_lt = code_root + "split/imagenet-100/imageNet_100_LT_train.txt"
txt_train_fs = code_root + "split/imagenet-100/imageNet_100_sub_balance_train_0.01.txt"

# Every split, including the two training subsets, is read through `transform_test`
# (presumably so that the extracted features are deterministic -- confirm intent).
train_datasets_lt, train_datasets_fs, test_datasets = (
    LT_Dataset(root=data_root, txt=split_txt, transform=transform_test)
    for split_txt in (txt_train_lt, txt_train_fs, txt_test)
)

# Loaders use the DataLoader default ordering (no shuffle argument is passed) and load in the main process.
batch_size = 64
train_loader_lt, train_loader_fs, test_loader = (
    torch.utils.data.DataLoader(dataset, num_workers=0, batch_size=batch_size)
    for dataset in (train_datasets_lt, train_datasets_fs, test_datasets)
)

In [5]:
num_class = 100
bnNameCnt = -1  # forwarded unchanged to cvt_state_dict below
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = resnet50(num_classes=num_class, imagenet='imagenet-100')
model = model.to(device)

path_checkpoint = ckp_root + 'simclr_TS.pt'
# NOTE: torch.load unpickles the file -- only open checkpoints that come from a trusted source.
checkpoint = torch.load(path_checkpoint, map_location="cpu")
# The weights may be nested under 'state_dict' or 'P_state', or be the top-level object itself.
if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
elif 'P_state' in checkpoint:
    state_dict = checkpoint['P_state']
else:
    state_dict = checkpoint

# Input width of the classifier head, or None when the model exposes no `fc.in_features`.
in_features = getattr(getattr(model, 'fc', None), 'in_features', None)
state_dict = cvt_state_dict(state_dict, bnNameCnt, in_features, num_classes=num_class)

# Left as the last expression so the cell displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
# Extract L2-normalised features for the long-tail training split.

model.eval()  # switch to evaluation mode once; no need to re-invoke it per batch
train_LT_batches = []
train_LT_gt = []
with torch.no_grad():
    for data, target, _ in tqdm(train_loader_lt):
        # Only the images are moved to the device; the labels are consumed on the host.
        feats = model(data.to(device), features=True)
        train_LT_batches.append(feats.cpu())
        train_LT_gt += target.tolist()

# Concatenate once after the loop instead of re-copying the growing tensor on every batch.
train_LT_features = torch.cat(train_LT_batches, dim=0).numpy()

# Scale every row to unit Euclidean length.
data_normed = LA.norm(train_LT_features, 2, axis=-1)
train_LT_features = train_LT_features / data_normed.reshape(-1, 1)

100%|██████████| 191/191 [06:03<00:00,  1.90s/it]


In [8]:
# Extract L2-normalised features for the few-shot (balanced subset) training split.

model.eval()  # switch to evaluation mode once; no need to re-invoke it per batch
train_FS_batches = []
train_FS_gt = []
with torch.no_grad():
    for data, target, _ in tqdm(train_loader_fs):
        # Only the images are moved to the device; the labels are consumed on the host.
        feats = model(data.to(device), features=True)
        train_FS_batches.append(feats.cpu())
        train_FS_gt += target.tolist()

# Concatenate once after the loop instead of re-copying the growing tensor on every batch.
train_FS_features = torch.cat(train_FS_batches, dim=0).numpy()

# Scale every row to unit Euclidean length.
data_normed = LA.norm(train_FS_features, 2, axis=-1)
train_FS_features = train_FS_features / data_normed.reshape(-1, 1)

100%|██████████| 19/19 [00:38<00:00,  2.04s/it]


In [10]:
# Extract L2-normalised features for the test split.

model.eval()  # switch to evaluation mode once; no need to re-invoke it per batch
test_batches = []
test_gt = []
with torch.no_grad():
    for data, target, _ in tqdm(test_loader):
        # Only the images are moved to the device; the labels are consumed on the host.
        feats = model(data.to(device), features=True)
        test_batches.append(feats.cpu())
        test_gt += target.tolist()

# Concatenate once after the loop instead of re-copying the growing tensor on every batch.
test_features = torch.cat(test_batches, dim=0).numpy()

# Scale every row to unit Euclidean length.
data_normed = LA.norm(test_features, 2, axis=-1)
test_features = test_features / data_normed.reshape(-1, 1)

100%|██████████| 79/79 [02:36<00:00,  1.99s/it]


In [11]:
# Pairwise distances (torch.cdist default p-norm) between every test feature and every
# long-tail training feature: one row per test sample, one column per training sample.
dist_tmp = torch.cdist(
    torch.tensor(test_features),
    torch.tensor(train_LT_features),
)
# For each test sample, training indices ordered from nearest to farthest.
predicted = dist_tmp.argsort(dim=1)
sorted_dist_test_lt = predicted.numpy()

In [12]:
def knn_class_accuracy(neighbor_idx, train_labels, test_labels, num_classes, k):
    """Class-balanced k-nearest-neighbour accuracy.

    Parameters
    ----------
    neighbor_idx : integer array of shape (n_test, >= k); row i lists training
        indices ordered from nearest to farthest for test sample i.
    train_labels : integer array with the label of every training sample.
    test_labels : integer array with the label of every test sample.
    num_classes : number of classes; labels are compared against range(num_classes).
    k : number of neighbours that vote.

    Returns
    -------
    list of length num_classes + 1: the accuracy of each class followed by
    the mean of those per-class accuracies.
    """
    neighbor_labels = train_labels[neighbor_idx[:, :k]]  # (n_test, k)
    # votes[i, c] = how many of the k neighbours of test sample i carry label c.
    votes = (neighbor_labels[:, :, None] == np.arange(num_classes)).sum(axis=1)
    # argmax returns the first maximum, so ties go to the smallest class index.
    voted = votes.argmax(axis=1)

    per_class = []
    for cl in range(num_classes):
        in_class = test_labels == cl
        per_class.append((voted[in_class] == cl).sum() / in_class.sum())
    per_class.append(np.mean(per_class))
    return per_class


def linear_svm_accuracy(train_x, train_y, test_x, test_y):
    """Fit a LinearSVC on (train_x, train_y) and score it on (test_x, test_y).

    Returns the fitted classifier, its test predictions and the overall
    (sample-averaged) accuracy.
    """
    svm = LinearSVC(random_state=42)
    svm.fit(train_x, train_y)
    svm_predictions = svm.predict(test_x)
    return svm, svm_predictions, (svm_predictions == test_y).sum() / len(test_y)


print('ImageNet-100 Evaluation')

loc_gt = np.array(test_gt)
loc_argsort = sorted_dist_test_lt
train_loc_gt = np.array(train_LT_gt)

# Nearest-neighbour retrieval against the long-tail training features.
class_acc = knn_class_accuracy(loc_argsort, train_loc_gt, loc_gt, num_class, k=1)
print('KNN@1 ', f'{class_acc[-1]*100:.2f}')

knn = 10
class_acc = knn_class_accuracy(loc_argsort, train_loc_gt, loc_gt, num_class, k=knn)
print('KNN@10', f'{class_acc[-1]*100:.2f}')

local_test = test_features
local_gt_test = np.array(test_gt)

# Linear probe trained on the few-shot split.
local_data = train_FS_features
local_gt = np.array(train_FS_gt)
fs_svm, predictions, acc = linear_svm_accuracy(local_data, local_gt, local_test, local_gt_test)
print('FS    ', f'{acc * 100:.2f}')

# Linear probe trained on the long-tail split.
local_data = train_LT_features
local_gt = np.array(train_LT_gt)
np.random.seed(42)  # kept from the original cell: resets NumPy's global RNG state
lt_svm, predictions, acc = linear_svm_accuracy(local_data, local_gt, local_test, local_gt_test)
print('LT    ', f'{acc * 100:.2f}')

ImageNet-100 Evaluation
KNN@1  38.38
KNN@10 38.98
FS     45.18
LT     47.26
