# Model Testing

In [None]:
import sys
# Add the parent directory (relative to the current working directory) to the
# module search path; presumably this is what lets the `src.*` imports in the
# next cell resolve when the notebook runs from its own folder.
sys.path.append('../')

In [None]:
import os
import torch
import pandas as pd

from tqdm import tqdm

from sklearn.metrics import auc, accuracy_score, f1_score, recall_score, precision_score
from sklearn.metrics import confusion_matrix, roc_curve

from monai.networks.nets import DenseNet
from torch.utils.data import DataLoader

from src.data.dataset import LABELS, BrainMriDataset
from src.data.transforms import Transforms
from src.utils.model import load_model
from src.utils.visualisation import plot_confusion_matrix, plot_roc_curve

In [None]:
# Report the runtime environment so results can be tied to a PyTorch build
# and to whether a CUDA device was visible when the notebook ran.
print(f'PyTorch Version: {torch.__version__}')
print(f'Is CUDA Available: {torch.cuda.is_available()}')

In [None]:
# Locations are relative to the directory the notebook is executed from.
DATASET_CSV = '../data/processed/dataset_nifti.csv'
INPUT_PATH = '../models/'
LOGS_PATH = '../logs/'

# Data-loading settings used by the test DataLoader.
BATCH_SIZE = 16
NUM_WORKERS = 8

# Kept as a plain string because it is later passed both to `.to(...)` and to
# `torch.autocast(...)`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Fail fast, with a message naming the missing path, before any expensive
# work starts. Explicit raises are used instead of `assert` because asserts
# are stripped when Python runs with -O and carry no explanation.
if not os.path.exists(DATASET_CSV):
    raise FileNotFoundError(f'Dataset CSV not found: {DATASET_CSV}')
if not os.path.exists(INPUT_PATH):
    raise FileNotFoundError(f'Model input path not found: {INPUT_PATH}')

In [None]:
# Load the dataset index and keep only the rows whose `split` column is 'test'.
dataset = pd.read_csv(DATASET_CSV)
test_data = dataset.loc[dataset['split'].eq('test')]

In [None]:
# Transform pipeline supplied by the project helper (its exact steps are
# defined in src.data.transforms, not here).
test_transform = Transforms.get_data_loading()

# Dataset over the test rows only.
test_dataset = BrainMriDataset(transform=test_transform, dataset_df=test_data)

# No shuffling: evaluation does not need a random order, and a fixed order
# keeps the collected predictions aligned with the rows of `test_data`.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [None]:
# 3-D DenseNet with one input channel and one output channel; the
# hyper-parameters must match the checkpoint being loaded.
model = DenseNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    dropout_prob=0.2,
)
model = model.to(DEVICE)

# Project helper; presumably restores the trained weights into `model` in
# place from INPUT_PATH (its return value is not used here).
load_model(model, INPUT_PATH, DEVICE)

In [None]:
y_true = []
y_pred = []
model.eval()

with torch.no_grad():
    for step, batch in tqdm(enumerate(test_loader), 'Testing', len(test_loader)):
        with torch.autocast(DEVICE):
            labels = batch['label'].to(DEVICE).float().unsqueeze(1)
            images = batch['image'].to(DEVICE)

        y_pred_prob = torch.sigmoid(model(images))
        y_pred_label = (y_pred_prob > 0.5).float()

        y_true.extend(labels.cpu().numpy())
        y_pred.extend(y_pred_label.cpu().numpy())

In [None]:
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
fscore = f1_score(y_true, y_pred)

false_positive_rate, true_positive_rate, _ = roc_curve(y_true, y_pred)
roc_auc = auc(false_positive_rate, true_positive_rate)

print('AUC of the ROC Curve: {}'.format(roc_auc))
print('Accuracy-Score: {}'.format(accuracy))
print('Precision-Score: {}'.format(precision))
print('Recall-Score: {}'.format(recall))
print('F1-Score: {}'.format(fscore))

plot_confusion_matrix(confusion_matrix(y_true, y_pred), class_names=LABELS.keys(), figsize=(4, 4))
plot_roc_curve(false_positive_rate, true_positive_rate, figsize=(6, 4))