In [1]:
data_root = '../datasets/cifar100/'
ckp_root = '../checkpoints/cifar100/'
code_root = '../'

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors
from collections import Counter, defaultdict
from numpy import linalg as LA
from torchvision import transforms
import torch
from sklearn.svm import LinearSVC
from torch.utils.data.sampler import SubsetRandomSampler
import sys
sys.path.append(code_root)
from models.resnet import resnet18
from utils import cvt_state_dict
import torchvision
from tqdm import tqdm

In [2]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
])


train_datasets = torchvision.datasets.CIFAR100(root=data_root, 
                                              train=True, 
                                              download=True, 
                                              transform=transform_train)

Files already downloaded and verified


In [3]:
val_splits = {
1: 'cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy',
2: 'cifar100_imbSub_with_subsets/cifar100_split2_D_i.npy',
3: 'cifar100_imbSub_with_subsets/cifar100_split3_D_i.npy'
}

In [4]:
split = 1  # choose split to test [1,2,3]
batch_size = 64
train_idx = list(np.load(code_root + 'split/{}'.format(val_splits[split])))
train_sampler = SubsetRandomSampler(train_idx)
long_tail_loader = torch.utils.data.DataLoader(
    train_datasets,
    batch_size=batch_size, sampler=train_sampler, num_workers=0)
    
testset = torchvision.datasets.CIFAR100(root=data_root, train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified


In [5]:
num_class = 100
bnNameCnt = -1
device = 'cuda'
model = resnet18(pretrained=False, num_classes=num_class)
model = model.to(device)

path_checkpoint = ckp_root + f'simclr_TS_SP{split}.pt'
checkpoint = torch.load(path_checkpoint, map_location="cpu")
if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
elif 'P_state' in checkpoint:
    state_dict = checkpoint['P_state']
else:
    state_dict = checkpoint

try:
    in_features = model.fc.in_features
except AttributeError:
    in_features = None
state_dict = cvt_state_dict(state_dict, bnNameCnt, in_features, num_classes=num_class)

model.load_state_dict(state_dict)
# print('read checkpoint {}'.format(path_checkpoint))

<All keys matched successfully>

In [12]:
# extract train long tail split features

model.eval()
train_LT_features = None
train_LT_gt = []
with torch.no_grad():
    for data, target in tqdm(long_tail_loader):
        data, target = data.to(device), target.to(device)
        feats = model.eval()(data, features=True)
        if train_LT_features is None:
            train_LT_features = feats.detach().cpu()
        else:
            train_LT_features = torch.cat((train_LT_features, feats.detach().cpu()), dim=0)
        train_LT_gt += target.cpu().numpy().tolist()

data_normed = LA.norm(train_LT_features.numpy(), 2, axis=-1)
train_LT_features = train_LT_features.numpy() / data_normed.reshape(-1, 1)

100%|██████████| 153/153 [00:03<00:00, 46.48it/s]


In [7]:
# extract test features

model.eval()
test_features = None
test_gt = []
with torch.no_grad():
    for data, target in tqdm(test_loader):
        data, target = data.to(device), target.to(device)
        feats = model.eval()(data, features=True)
        if test_features is None:
            test_features = feats.detach().cpu()
        else:
            test_features = torch.cat((test_features, feats.detach().cpu()), dim=0)
        test_gt += target.cpu().numpy().tolist()
        
data_normed = LA.norm(test_features.numpy(), 2, axis=-1)
test_features = test_features.numpy() / data_normed.reshape(-1, 1)

100%|██████████| 157/157 [00:02<00:00, 66.62it/s]


In [14]:
dist_tmp = torch.cdist(torch.tensor(test_features), torch.tensor(train_LT_features))
predicted = torch.argsort(dist_tmp, dim=1)
sorted_dist_test_lt = predicted.numpy()

In [15]:
print(f'SPLIT {split}')

class_acc = []
loc_gt = np.array(test_gt)
loc_argsort = sorted_dist_test_lt
train_loc_gt = np.array(train_LT_gt)
for cl in range(num_class):
    predictions = loc_argsort[loc_gt == cl, 0]
    result = (train_loc_gt[predictions] == cl).sum() / (loc_gt == cl).sum()
    class_acc.append(result)
class_acc.append(np.mean(class_acc))
print('KNN@1 ', f'{class_acc[-1]*100:.2f}')

knn=10
class_acc = []
loc_gt = np.array(test_gt)
loc_argsort = sorted_dist_test_lt
train_loc_gt = np.array(train_LT_gt)
for cl in range(num_class):
    predict_mat = np.zeros(((loc_gt == cl).sum(), num_class))
    predictions = loc_argsort[loc_gt == cl, :knn]
    for cl2 in range(num_class):
        result2 = (train_loc_gt[predictions] == cl2).sum(1)
        predict_mat[:, cl2] = result2
    result = (np.argmax(predict_mat, 1) == cl).sum() / (loc_gt == cl).sum()   
    class_acc.append(result)
class_acc.append(np.mean(class_acc))
print('KNN@10', f'{class_acc[-1]*100:.2f}')

local_data = train_LT_features
local_gt = np.array(train_LT_gt)
local_counter = Counter(local_gt)
n_samples = local_counter.most_common()[-1][-1]
np.random.seed(42)
new_data = []
new_gt = []
for cl in range(num_class):
    local_idxs = np.where(local_gt == cl)[0]
    np.random.shuffle(local_idxs)
    selected_idxs = local_idxs[:n_samples]
    new_data.append(local_data[selected_idxs])
    new_gt.extend([cl] * n_samples)

new_data = np.concatenate(new_data, 0)
fs_svm = LinearSVC(random_state=42)
fs_svm.fit(new_data, new_gt)
local_test = test_features
local_gt_test = np.array(test_gt)
predictions = fs_svm.predict(local_test)
acc = (predictions == local_gt_test).sum() / len(local_gt_test)
print('FS    ', f'{acc * 100:.2f}')

np.random.seed(42)
lt_svm = LinearSVC(random_state=42)
lt_svm.fit(local_data, local_gt)
predictions = lt_svm.predict(local_test)
acc = (predictions == local_gt_test).sum() / len(local_gt_test)
print('LT    ', f'{acc * 100:.2f}')

SPLIT 1
KNN@1  29.12
KNN@10 28.11
FS     27.02
LT     32.29
