# Inference Demo

In [None]:
# Make the parent of the current working directory importable. The notebook
# imports the project-local `src` package from there.
import sys

sys.path.extend(['../'])

In [None]:
import torch

from tkinter import Tk, filedialog
from monai.networks.nets import DenseNet
from src.data.transforms import Transforms
from src.utils.model import load_model
from src.utils.visualisation import plot_scan_central_slices

In [None]:
# Report the runtime environment: installed PyTorch version and GPU visibility.
for caption, value in (
    ('PyTorch Version:', torch.__version__),
    ('Is CUDA Available:', torch.cuda.is_available()),
):
    print(caption, value)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Directory handed to load_model when restoring the network.
INPUT_PATH = '../models/'

# Dropout probability passed to the DenseNet constructor.
# NOTE(review): presumably must match the value used at training time — confirm.
DROPOUT_PROB = 0.2

In [None]:
# Ask the user to pick a NIfTI scan through a native file dialog.
# A Tk root window is required to host the dialog. It is hidden with
# withdraw() and destroyed afterwards, so no stray Tk interpreter or window
# is left behind in the notebook kernel, even if the dialog raises.
root = Tk()
root.withdraw()
try:
    scan_path = filedialog.askopenfilename(
        title='Select a Brain MRI Scan',
        filetypes=[('NIfTI', '*.nii.gz')]
    )
finally:
    root.destroy()

# askopenfilename returns an empty value when the dialog is cancelled.
# Fail here with a clear message rather than later inside the data transforms.
if not scan_path:
    raise ValueError('No scan selected; please choose a .nii.gz file.')

In [None]:
# Build a 3-D MONAI DenseNet with one input channel and one output unit. The
# output is later passed through a sigmoid for binary classification.
# Place the network on DEVICE, then hand it to the project's load_model helper.
# NOTE(review): load_model presumably restores trained weights from
# INPUT_PATH into `model` in place — confirm against src.utils.model.
model = DenseNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    dropout_prob=DROPOUT_PROB,
).to(DEVICE)
load_model(model, INPUT_PATH, DEVICE)

In [None]:
# Put the network in evaluation mode, so modules such as dropout behave
# deterministically during inference.
model.eval()

# Preprocess the selected file with the project's data-loading transforms.
inference_transforms = Transforms.get_data_loading()
data = {'image': scan_path}
transformed = inference_transforms(data)['image']
# Add a leading (batch) axis and move the tensor to the inference device.
scan = transformed.unsqueeze(0).to(DEVICE)

with torch.no_grad():
    y_pred_prob = torch.sigmoid(model(scan))
    # Threshold the probability at 0.5 to obtain a hard 0/1 prediction.
    y_pred_label = (y_pred_prob > 0.5).float()
    if y_pred_label.cpu().numpy()[0][0]:
        label = 'AD'
    else:
        label = 'CN'

# Remove the leading size-1 axes and convert to a host-side NumPy array for plotting.
scan = scan.squeeze(0).squeeze(0)
scan = scan.cpu().numpy()

In [None]:
# Title the figure with the predicted class and render the central slices of
# the preprocessed volume.
description = 'Predicted [%s]' % (label,)
plot_scan_central_slices(
    scan.shape,
    scan,
    description,
    figsize=(9, 6),
    padding=16,
)