In [6]:
import random
import time
import warnings
import tqdm
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import models
from utils import *
from dataset.imbalance_cifar import ImbalanceCIFAR10, ImbalanceCIFAR100
from sklearn.manifold import TSNE
from matplotlib import cm
%matplotlib inline

%load_ext autoreload
%autoreload 2

model_path = 'checkpoint/cifar10_resnet32_Focal_None_exp_0.01_stage1/ckpt.best.pth.tar'
model_path = 'checkpoint/cifar10_resnet32_LDAM_None_exp_0.01_stage1/ckpt.best.pth.tar'
model_path = 'checkpoint/cifar10_resnet32_LDAM_Resample_exp_0.01_stage2_freeze/ckpt.best.pth.tar'
# model_path = 'checkpoint/cifar10_resnet32_LDAM_DRW_exp_0.01_stage2/ckpt.best.pth.tar'


# model_path = 'checkpoint/cifar10_resnet32_LDAM_None_exp_0.01_stage1/ckpt.pth.tar'
# model_path = 'checkpoint/cifar10_resnet32_LDAM_Reweight_exp_0.01_stage2_freeze/ckpt.pth.tar'
model_path = "checkpoint/cifar10_resnet32_LDAM_DRW_exp_0.01_stage2_upperbound_freeze/ckpt.best.pth.tar"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
use_norm = True if 'LDAM' in model_path else False
model = models.__dict__['resnet32'](num_classes=10, use_norm=use_norm)

checkpoint = torch.load(model_path, map_location=torch.device(f'cuda:0'))
model.load_state_dict(checkpoint['state_dict'])


model.cuda()

ResNet_s(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affin

In [8]:
mean = [0.4914, 0.4822, 0.4465] 
std = [0.2023, 0.1994, 0.2010]
    
transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

train_dataset = ImbalanceCIFAR10(
            root='/media/data', imb_type='exp', imb_factor=0.01,
            rand_number=0, train=True, download=True, transform=transform_val)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=False,
                                                 num_workers=16, pin_memory=True)


val_dataset = datasets.CIFAR10(root='/media/data',
                                       train=False, download=True, transform=transform_val)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False,
                                                 num_workers=16, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
model.eval()
feature_extractor = nn.Sequential(*list(model.children())[:-1])
feature_extractor.eval()

all_features = []
all_labels = []
with torch.no_grad():
    for inputs, target in tqdm.tqdm(train_loader):
        inputs = inputs.cuda()
        target = target.cuda()

        output = feature_extractor(inputs)
        output = output.view(output.size(0),-1)
        
        
        for i in range(output.size(0)):  
            all_features.append(output[i].detach().cpu().numpy())
            all_labels.append(target[i].detach().cpu().numpy())
            
all_features = np.array(all_features)
all_labels = np.array(all_labels)


100%|██████████| 125/125 [00:01<00:00, 64.92it/s] 


In [10]:
def get_centroids(feats_, labels_):
    centroids = []        
    for i in np.unique(labels_):
        centroids.append(np.mean(feats_[labels_==i], axis=0))
    return np.stack(centroids)

featmean = all_features.mean(axis=0)

# Get cl2n centorids
cl2n_feats = torch.Tensor(all_features.copy())
cl2n_feats = cl2n_feats - torch.Tensor(featmean)
norm_cl2n = torch.norm(cl2n_feats, 2, 1, keepdim=True)
cl2n_feats = cl2n_feats / norm_cl2n
cl2n_centers = get_centroids(cl2n_feats.numpy(), all_labels)

In [11]:
def l2_similarity(A, B):
    # input A: [bs, fd] (batch_size x feat_dim)
    # input B: [nC, fd] (num_classes x feat_dim)
    feat_dim = A.size(1)

    AB = torch.mm(A, B.t())
    AA = (A**2).sum(dim=1, keepdim=True)
    BB = (B**2).sum(dim=1, keepdim=True)
    dist = AA + BB.t() - 2*AB

    return -dist

def cos_similarity(A, B):
    feat_dim = A.size(1)
    AB = torch.mm(A, B.t())
    AB = AB / feat_dim
    return AB

In [12]:
all_features = []
all_labels = []
with torch.no_grad():
    for inputs, target in tqdm.tqdm(val_loader):
        inputs = inputs.cuda()
        target = target.cuda()

        output = feature_extractor(inputs)
        output = output.view(output.size(0),-1).detach().cpu()
        
        output -= featmean
        norm_x = torch.norm(output, 2, 1, keepdim=True)
        output = output / norm_x
        
        for i in range(output.size(0)):  
            all_features.append(output[i].numpy())
            all_labels.append(target[i].detach().cpu().numpy())

all_features = np.array(all_features)
all_labels = np.array(all_labels)

100%|██████████| 100/100 [00:01<00:00, 72.77it/s]


In [13]:
dists = l2_similarity(torch.from_numpy(all_features), torch.from_numpy(cl2n_centers))
ncm_pred = np.array(dists.argmax(1))
ncm_acc = (ncm_pred==all_labels).sum()/len(all_labels)

cf = confusion_matrix(all_labels, ncm_pred).astype(float)
cls_cnt = cf.sum(axis=1)
cls_hit = np.diag(cf)
cls_acc = cls_hit / cls_cnt
print(f"NCM acc {ncm_acc*100:.6}%")
print(f"Class accuracy {cls_acc*100}")


NCM acc 90.91%
Class accuracy [87.6 94.9 85.  84.1 91.  88.  95.5 94.  94.  95. ]
