In [None]:
import os
import PIL.Image
import numpy as np
import pandas as pd
import torch
import torchvision
import cv2
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve

In [None]:
# ---- evaluation settings ----

# dataset root directory and the square edge length images are resized to
mydatadir = 'dataset/'
img_size = 224

# DataLoader settings
batch_size = 20
num_workers = 10

# class names, indexed by integer label
class_names = ['no-STAS', 'STAS']
num_classes = len(class_names)

# name of the split to evaluate (blank placeholder; fill in before running)
split_test = ''

# path of the saved model to evaluate (blank placeholder; fill in before running)
model_path = ''


In [None]:
def bootstrap_auc(y, preds, fold_size, bootstraps=100, random_state=None):
    """Estimate a bootstrap 95% confidence interval for the ROC AUC.

    Every bootstrap replicate draws, with replacement, about ``fold_size``
    samples while preserving the positive/negative ratio of the full data
    (stratified resampling), and is scored with ``roc_auc_score``.

    Args:
        y: 1-D array-like of binary ground-truth labels (0/1 or bool).
        preds: 1-D array-like of positive-class scores, same length as ``y``.
        fold_size: Target number of samples in each bootstrap replicate.
        bootstraps: Number of replicates to draw; must be at least 1.
        random_state: Optional integer seed for reproducible resampling.
            ``None`` (the default) draws from NumPy's global random state,
            so an earlier ``np.random.seed(...)`` call is still honoured.

    Returns:
        Tuple ``(lower, middle, upper)``, each rounded to 4 decimals.
        ``lower``/``upper`` are the sorted bootstrap AUCs at positions
        ``int(0.025 * bootstraps)`` and ``int(0.975 * bootstraps)``;
        ``middle`` is the AUC computed on the full, un-resampled data.

    Raises:
        ValueError: If ``bootstraps`` is below 1, or if ``y`` lacks either
            positive or negative samples (AUC is undefined in that case).
    """
    if bootstraps < 1:
        raise ValueError('bootstraps must be >= 1, got %r' % (bootstraps,))

    y_true = np.asarray(y)
    scores = np.asarray(preds)

    # Split the scores by class once; resampling then only needs index draws.
    pos_scores = scores[y_true == 1]
    neg_scores = scores[y_true == 0]
    if len(pos_scores) == 0 or len(neg_scores) == 0:
        raise ValueError(
            'bootstrap_auc needs at least one positive and one negative sample'
        )

    prevalence = len(pos_scores) / len(y_true)
    # Keep at least one sample per class: a single-class replicate would make
    # roc_auc_score fail, which happened for small fold_size / rare classes.
    n_pos = max(1, int(fold_size * prevalence))
    n_neg = max(1, int(fold_size * (1 - prevalence)))

    # np.random (module) and RandomState expose the same randint() API.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    # The label vector is identical for every replicate, so build it once.
    y_sample = np.concatenate([np.ones(n_pos, dtype=int), np.zeros(n_neg, dtype=int)])

    auc_values = np.empty(bootstraps, dtype=float)
    for b in range(bootstraps):
        pos_draw = pos_scores[rng.randint(0, len(pos_scores), size=n_pos)]
        neg_draw = neg_scores[rng.randint(0, len(neg_scores), size=n_neg)]
        pred_sample = np.concatenate([pos_draw, neg_draw])
        auc_values[b] = roc_auc_score(y_sample, pred_sample)
    auc_values.sort()

    _lower = round(float(auc_values[int(0.025 * len(auc_values))]), 4)
    _upper = round(float(auc_values[int(0.975 * len(auc_values))]), 4)
    _middle = round(roc_auc_score(y, preds), 4)

    return _lower, _middle, _upper


In [None]:
# define dataset
class ClassificationDataset(torch.utils.data.Dataset):
    """Image-classification dataset driven by a ``<split>.txt`` index file.

    Every non-blank line of ``<datadir>/<split>.txt`` holds an image path
    (relative to ``datadir``) followed by an integer class label; the two
    fields are separated by whitespace (typically a tab). The label is used
    as an index into the notebook-level ``class_names`` list.
    """

    def __init__(self, datadir='', split='train', transforms=None):
        """Read the index file and print per-class image counts.

        Args:
            datadir: Directory containing the index file and the images.
            split: Index file name without the ``.txt`` extension.
            transforms: Optional callable applied to each loaded PIL image.
        """
        self.split = split
        self.transforms = transforms

        self.imdb = []
        stats = {lbl: 0 for lbl in class_names}

        with open(os.path.join(datadir, split + '.txt')) as ff:
            for line in ff:
                parsed = self._parse_line(line)
                if parsed is None:
                    # tolerate blank lines (e.g. a trailing newline at EOF)
                    continue
                imgname, label = parsed
                self.imdb.append({
                    'imgpath': os.path.join(datadir, imgname),
                    'label': label,
                })
                stats[class_names[label]] += 1

        print('split: %s, total image num: %d' % (split, len(self.imdb)))
        for classname in stats:
            print('    %s: %d' % (classname, stats[classname]))

    @staticmethod
    def _parse_line(line):
        """Parse one index line into ``(imgname, label)``; ``None`` if blank.

        The label is the last whitespace-delimited token. Splitting from the
        right on whitespace replaces the earlier split on the letter ``t``,
        which broke on any file name containing that letter, and also keeps
        image names that contain spaces intact.

        Raises:
            ValueError: If the line has no separate label field or the label
                is not an integer.
        """
        line = line.strip()
        if not line:
            return None
        fields = line.rsplit(None, 1)
        if len(fields) != 2:
            raise ValueError(
                "expected '<image path><whitespace><label>', got %r" % (line,)
            )
        imgname, label = fields
        return imgname, int(label)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``; image is RGB, transformed if set."""
        entry = self.imdb[index]

        # convert() yields a fully loaded copy, so the file can be closed here
        with PIL.Image.open(entry['imgpath']) as raw:
            img = raw.convert('RGB')

        if self.transforms is not None:
            img = self.transforms(img)

        return img, entry['label']

    def __len__(self):
        """Number of indexed images."""
        return len(self.imdb)

In [None]:
# Test-time preprocessing: resize to a square, convert to a tensor, then
# standardise each channel (values match the commonly used ImageNet mean/std).
normalize = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transforms_test = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize([img_size, img_size]),
        torchvision.transforms.ToTensor(),
        normalize,
    ]
)

# Dataset for the configured evaluation split.
test_dataset = ClassificationDataset(
    transforms=transforms_test,
    split=split_test,
    datadir=mydatadir,
)

# Sequential loader: no shuffling and the final partial batch is kept.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)


In [None]:
# test model
# NOTE(review): torch.load unpickles arbitrary Python objects, so `model_path`
# must point to a trusted file. Newer PyTorch releases may additionally need
# `weights_only=False` to load a whole pickled model -- confirm against the
# installed version.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load(model_path, map_location=device)
model = model.to(device)
model.eval()

criterian = torch.nn.CrossEntropyLoss()

if len(test_dataset) == 0:
    raise ValueError('test split is empty; nothing to evaluate')

# ---- inference over the test split ----
testloss = 0.
testacc = 0.
prob_batches = []
label_batches = []
with torch.no_grad():
    # The loader yields (inputs, labels) pairs; the earlier
    # `idx, (inputs, labels)` unpacking without enumerate() could not work.
    for inputs, labels in test_loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        output = model(inputs)
        prob = F.softmax(output, dim=1)
        # CrossEntropyLoss returns a batch mean; weight it by the batch size so
        # a shorter final batch does not skew the dataset-level average.
        testloss += criterian(output, labels).item() * labels.size(0)
        testacc += (output.argmax(dim=1) == labels).sum().item()
        # Collect per-batch arrays and concatenate once (np.append in a loop
        # re-copies the whole array on every batch).
        prob_batches.append(prob.view(-1, num_classes).cpu().numpy())
        label_batches.append(labels.cpu().numpy())

predlist = np.concatenate(prob_batches, axis=0)
gtlist = np.concatenate(label_batches)

testloss /= len(test_dataset)
testacc /= len(test_dataset)
print('Test loss: %.4f, Test acc: %.3f' % (testloss, testacc))

# ---- per-class (one-vs-rest) metrics and figures ----
for i in range(num_classes):
    gt = (gtlist == i).astype(int)
    pred = predlist[:, i]
    fpr, tpr, _ = roc_curve(gt, pred)
    auc = bootstrap_auc(gt, pred, gt.shape[0])
    binary_pred = (pred >= 0.5).astype(np.int16)

    # labels=[0, 1] guarantees a 2x2 matrix even if one outcome never occurs,
    # so the four-way unpacking below cannot fail.
    cm_show = confusion_matrix(gt, binary_pred, labels=[0, 1])
    tn, fp, fn, tp = cm_show.ravel()
    acc = (tp + tn) / (tn + fp + fn + tp) * 100
    sen = tp / (tp + fn) * 100
    spec = tn / (tn + fp) * 100
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1_score = 2 * precision * recall / (precision + recall)
    # f1_score is on a 0-1 scale like recall/precision, so it is printed with
    # three decimals (one decimal discarded almost all of its resolution).
    print('class: %s, auc:%.3f[%.3f, %.3f], acc:%.1f, sensitive:%.1f, specificity:%.1f, recall:%.3f, precision:%.3f, f1_score:%.3f' % (
        class_names[i], auc[1], auc[0], auc[2], acc, sen, spec, recall, precision, f1_score
    ))

    # ROC curve
    roc_fig, roc_ax = plt.subplots(figsize=(5.12, 5.12))
    roc_ax.plot(fpr, tpr, lw=2, color='r', label='%s (AUC=%.3f[%.3f-%.3f])' % (class_names[i], auc[1], auc[0], auc[2]))
    roc_ax.plot([0, 1], [0, 1], '--', color='gray')
    roc_ax.set_title('AUROC of %s' % class_names[i], fontsize=15)
    roc_ax.set_xlabel('1 - Specificity', fontdict={'fontsize': 15})
    roc_ax.set_ylabel('Sensitivity', fontdict={'fontsize': 15})
    # without a legend the AUC label attached to the curve is never shown
    roc_ax.legend(loc='lower right')

    # confusion matrix
    fig, ax = plt.subplots(figsize=(5.12, 5.12))
    im = ax.imshow(cm_show, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)
    # Tick labels follow the matrix shape (always 2x2 here), not num_classes,
    # so this also works when there are more than two classes.
    ax.set(
        xticks=np.arange(cm_show.shape[1]),
        yticks=np.arange(cm_show.shape[0]),
        xticklabels=list(range(cm_show.shape[1])),
        yticklabels=list(range(cm_show.shape[0])),
        title='Confusion matrix',
        ylabel='Actual',
        xlabel='Predicted'
    )

    thresh = cm_show.mean()
    print(cm_show, thresh)
    # row/col instead of i/j: the earlier inner loop reused the class index `i`
    for row in range(cm_show.shape[0]):
        for col in range(cm_show.shape[1]):
            ax.text(col, row, format(int(cm_show[row, col]), 'd'),
                fontsize=20,
                ha="center", va="center",
                color="white" if cm_show[row, col] > thresh else "black"
            )
    fig.tight_layout()

