In [23]:
from dataset.cifar10 import get_cifar10_dataloaders, get_cifar10_dataloaders_sample
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn
import numpy as np
from models import model_dict

from sklearn.metrics import confusion_matrix

In [13]:
def get_teacher_name(model_path):
    """Parse the model architecture name from a checkpoint path.

    The architecture is encoded in the name of the checkpoint's parent directory,
    as its first ``_``-separated segment, optionally prefixed with a role tag such
    as ``S:`` (e.g. ``S:resnet110_T:resnet110_cifar10_kd_...`` -> ``resnet110``).
    Wide ResNets span three segments (``wrn_40_2``), with or without the prefix.

    Args:
        model_path: ``/``-separated path whose parent directory encodes the architecture.

    Returns:
        The architecture name, e.g. ``'resnet110'`` or ``'wrn_40_2'``.
    """
    segments = model_path.split('/')[-2].split('_')
    # Drop any "<role>:" prefix (e.g. "S:") so "S:wrn_16_2_..." is handled like "wrn_16_2_...".
    arch = segments[0].split(':')[-1]
    if arch == 'wrn':
        # Wide ResNet names carry depth and width as the next two segments.
        return '_'.join([arch] + segments[1:3])
    return arch


def load_teacher(model_path, n_cls, map_location='cpu'):
    """Build a model from ``model_dict`` and load its weights from a checkpoint.

    The architecture name is parsed from the checkpoint's parent directory by
    ``get_teacher_name``; the weights are read from the checkpoint's ``'model'`` entry.

    Args:
        model_path: path to a checkpoint file saved as a dict holding a ``'model'``
            state dict.
        n_cls: number of output classes passed to the model constructor.
        map_location: device the checkpoint tensors are deserialized onto. Defaults
            to ``'cpu'`` so GPU-saved checkpoints also load on CPU-only machines;
            the weights are copied into the freshly built model either way, so move
            the returned model to a GPU yourself if needed.

    Returns:
        The constructed model with the checkpoint weights loaded.
    """
    print('==> loading teacher model')
    model_t = get_teacher_name(model_path)
    model = model_dict[model_t](num_classes=n_cls)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=map_location)
    model.load_state_dict(checkpoint['model'])
    print('==> done')
    return model

In [26]:
# Checkpoint under evaluation. Judging by the run directory name it is presumably a
# resnet110 student distilled from a resnet110 teacher on CIFAR-10 — confirm against
# the training logs.
path = r'save/student_model/S:resnet110_T:resnet110_cifar10_kd_r:0.1_a:0.9_b:0.0_maskhead9_Resample/resnet110_best.pth'
n_cls = 10
model = load_teacher(path, n_cls)
# Only move to the GPU when one exists, matching the guarded device handling used
# when evaluating.
if torch.cuda.is_available():
    model = model.cuda()
train_loader, val_loader, n_data = get_cifar10_dataloaders(batch_size=64,
                                                           num_workers=8,
                                                           is_instance=True,
                                                           train_rule=None)

==> loading teacher model
==> done
Files already downloaded and verified
Files already downloaded and verified
train cls num list:
[5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]


In [27]:
model.eval()

# Send each batch to wherever the model's weights live, so this cell works whether
# or not the model was moved to the GPU.
device = next(model.parameters()).device

all_preds = []
all_targets = []
with torch.no_grad():
    for images, labels in val_loader:
        logits = model(images.float().to(device))
        # Predicted class = index of the largest logit for each sample.
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        # Labels are only needed on the CPU for the confusion matrix, so they are
        # never copied to the GPU.
        all_targets.extend(labels.cpu().numpy())


In [29]:
# Rows = true class, columns = predicted class. Passing `labels` pins the matrix to
# n_cls x n_cls even if some class never appears in targets or predictions, so row i
# always corresponds to class i.
cf = confusion_matrix(all_targets, all_preds, labels=np.arange(n_cls)).astype(float)
cls_cnt = cf.sum(axis=1)   # validation samples per true class
cls_hit = np.diag(cf)      # correctly classified samples per class
# Per-class accuracy (recall). A class with no validation samples gets NaN instead
# of triggering a divide-by-zero warning.
cls_acc = np.divide(cls_hit, cls_cnt, out=np.full_like(cls_hit, np.nan), where=cls_cnt > 0)

out_cls_acc = 'Class Accuracy: %s' % np.array2string(
    cls_acc, separator=',', formatter={'float_kind': lambda x: "%.3f" % x})

In [30]:
# Display the formatted per-class accuracy string (one value per class, 3 decimals).
out_cls_acc

'Class Accuracy: [0.817,0.474,0.825,0.608,0.748,0.494,0.542,0.465,0.253,0.976]'

In [31]:
# Raw confusion matrix: rows = true class, columns = predicted class.
# NOTE(review): column 9 absorbs a large share of every row's errors, even though
# class 9 had the fewest training samples (50 in the printed class counts) — worth
# investigating.
cf

array([[817.,   0.,  13.,   1.,   3.,   0.,   0.,   0.,   1., 165.],
       [  1., 474.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 525.],
       [ 30.,   0., 825.,  27.,  25.,   7.,   5.,   0.,   0.,  81.],
       [ 12.,   0.,  70., 608.,  25.,  23.,   4.,   1.,   1., 256.],
       [ 12.,   0.,  74.,  47., 748.,   7.,   2.,   3.,   0., 107.],
       [  5.,   0.,  72., 213.,  22., 494.,   1.,   9.,   0., 184.],
       [  6.,   0., 130., 109.,  18.,   3., 542.,   0.,   0., 192.],
       [ 30.,   1.,  60.,  78.,  85.,  29.,   0., 465.,   0., 252.],
       [167.,   1.,   7.,   4.,   1.,   0.,   1.,   0., 253., 566.],
       [ 15.,   4.,   2.,   1.,   1.,   0.,   0.,   0.,   1., 976.]])

In [32]:
# Sanity check: number of validation samples per true class (same values as `cls_cnt`).
cf.sum(axis=1)

array([1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000., 1000.,
       1000.])