In [1]:
import os
os.chdir('../src')

In [2]:
import datetime, sys

import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
from sklearn.metrics import confusion_matrix

from data_builder.builder import build_valid_loader
from models.create_model import CustomNet
from config import _C as cfg


In [3]:
test_loader = build_test_loader(cfg)
model = CustomNet(cfg)

checkpoint = torch.load(cfg.CHECKPOINT_PATH)
model.load_state_dict(checkpoint['model_state_dict'])
del checkpoint

In [4]:
model = model.to(cfg.DEVICE)
model.eval()

preds = []
target = []
confidence = []

test_loader = tqdm(test_loader, total=len(test_loader), desc='Valid')
for cnt, (imgs, targets, idxs) in enumerate(test_loader):

    imgs = imgs.to(cfg.DEVICE)

    with torch.no_grad():
        logits = model(imgs)
        
        softmax = torch.nn.Softmax(dim=1)(logits).detach().cpu().numpy()
        prediction_class = np.argmax(a=softmax, axis=1)
        conf = np.max(a=softmax, axis=1)

        preds.append(prediction_class)
        confidence.append(conf)
        target.append(targets.detach().cpu().numpy())

target = np.concatenate(target, axis=0)
confidence = np.concatenate(confidence, axis=0)
preds = np.concatenate(preds, axis=0)

Valid: 100%|██████████| 457/457 [02:00<00:00,  3.79it/s]


In [5]:
cm = confusion_matrix(target, preds)
print(cm)

[[ 258   16   16    9   55]
 [  26  664   28   20   46]
 [   4   17  767   46   23]
 [   5   14   80 4443   34]
 [  40   43   60   47  538]]


In [6]:
for i, val in enumerate(cm):
    print("for class {}: accuracy: {}".format(i, val[i]/sum(val)*100))

for class 0: accuracy: 72.88135593220339
for class 1: accuracy: 84.6938775510204
for class 2: accuracy: 89.49824970828472
for class 3: accuracy: 97.09353146853147
for class 4: accuracy: 73.9010989010989


In [7]:
np.sum(target==preds)/len(target)*100

91.38238114810248

In [9]:
confidence

array([0.9758552 , 0.98574173, 0.9795562 , ..., 0.98359585, 0.9695898 ,
       0.8129119 ], dtype=float32)

In [19]:
len(confidence)

368

In [21]:
len(target)

23

In [24]:
idx = np.argsort(confidence)[:5]

In [25]:
confidence[idx]

array([0.335761  , 0.3588644 , 0.4413463 , 0.44890353, 0.45404968],
      dtype=float32)