In [None]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np

from pathlib import Path
import pandas as pd
import nibabel as nib

import torch
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn.functional as F

import sklearn
import sklearn.metrics

import gradio as gr

In [None]:
# Dataset root, resolved relative to the notebook's current working directory.
dataset_dir = Path.cwd() / "Datasets" / "ADNI" / "ADNI1"
print(dataset_dir)

# Sub-directories of the dataset root used by this notebook.
imgs_dir = dataset_dir / "Images"
masks_dir = dataset_dir / "SegmentationMasks"
store_dir = dataset_dir / "InputImages"
input_dir = store_dir

print(imgs_dir)
print(masks_dir)

# Recursively collect every CSV file below the dataset root. Materialising the
# generator lets us detect the "nothing found" case explicitly instead of
# failing later with a NameError on an unbound variable.
csv_files = list(dataset_dir.glob("**/*.csv"))
if not csv_files:
    raise FileNotFoundError(f"No CSV file found under {dataset_dir}")

# When several CSV files exist, the last one yielded by the search is used.
csv_data = csv_files[-1]

print(csv_data)
df = pd.read_csv(csv_data)

In [None]:
# Import the project-local `models.ResNet` module and immediately reload it, so
# that re-running this cell picks up edits to the module without restarting
# the kernel. The result of `reload` is assigned to `_` to suppress cell output.
import importlib
import models.ResNet as RN
_ = importlib.reload(RN)

In [None]:
# Class name -> integer label; its size determines the number of output classes.
label_dict = {
    "CN": 0,
    "AD": 1,
}

num_classes = len(label_dict)

# Number of input channels of the first convolution. `predict` feeds 8 coronal
# slices when this is 8 and a single slice otherwise.
count = 8

# Each channel configuration is stored in its own checkpoint file.
if count == 1:
    save_path = "model_resnet_best.pth"
else:
    save_path = "model_resnet18_8slices.pth"

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RN.ResNet18(device, num_classes)

# Replace the first convolution so that it accepts `count` input channels while
# keeping the remaining hyper-parameters of the original layer. `Conv2d`
# expects a boolean for `bias`; passing the bias tensor itself raises on the
# truth-value test when the original layer has a bias, so test for presence.
model.conv1 = torch.nn.Conv2d(
    count,
    model.conv1.out_channels,
    kernel_size=model.conv1.kernel_size,
    stride=model.conv1.stride,
    padding=model.conv1.padding,
    bias=model.conv1.bias is not None,
)
# Replace the final linear layer so that it emits one score per class.
model.fc2 = torch.nn.Linear(in_features=model.fc2.in_features, out_features=num_classes, bias=True)

if os.path.exists(save_path):
    print(f"found saved state at: {save_path}")
    # `map_location` lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine. NOTE(review): `torch.load` unpickles the file; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

_ = model.to(device)

In [None]:
# Accepted volume shapes, each mapped to the slice index that `predict` passes
# on as the centre of the coronal slice window.
input_shapes = {(256, 256, 166): 128}

# Text reported to the user for each class index chosen from the model output.
predict_labels = {
    0: "Cognitively Normal",
    1: "Alzheimer's Disease",
}

def read_image(path):
    """Load the image stored at ``path`` with nibabel.

    Args:
        path: Location of the image file, as accepted by ``nib.load``.

    Returns:
        The array returned by the loaded image's ``get_fdata()``.
    """
    return nib.load(path).get_fdata()

def min_max_normalize(image, new_min=0, new_max=1):
    """Linearly rescale ``image`` so its values span ``[new_min, new_max]``.

    Args:
        image: Array-like of numeric values.
        new_min: Value the smallest input element is mapped to.
        new_max: Value the largest input element is mapped to.

    Returns:
        A new array with the same shape as ``image``. When every element of
        ``image`` is identical (so there is no range to stretch), an array
        filled with ``new_min`` is returned instead of the NaNs a division by
        zero would produce.
    """
    image = np.asarray(image)
    min_val = np.min(image)
    max_val = np.max(image)
    value_range = max_val - min_val

    # A constant image (e.g. an empty slice) has zero range; dividing by it
    # would yield NaNs that silently poison everything downstream.
    if value_range == 0:
        return np.full(image.shape, new_min, dtype=float)

    return (image - min_val) / value_range * (new_max - new_min) + new_min

def _coronal_window(img, start, stop):
    """Return the normalised slices ``img[:, i, :]`` for ``i`` in ``[start, stop)``.

    Each slice is rotated by 180 degrees (``np.rot90(..., k=2)``) and rescaled
    with ``min_max_normalize`` before being collected.

    Raises:
        IndexError: If the requested window does not lie inside axis 1 of
            ``img``. Without this check a negative ``start`` would silently
            wrap around and pick slices from the opposite end of the volume.
    """
    depth = img.shape[1]
    if start < 0 or stop > depth:
        raise IndexError(
            f"coronal slice window [{start}, {stop}) is outside the volume "
            f"(axis 1 has {depth} slices)"
        )
    return [
        min_max_normalize(np.rot90(img[:, index, :], k=2))
        for index in range(start, stop)
    ]


def separate_coronal_slices_around(img, slice_number, slices_count):
    """Extract ``2 * slices_count`` coronal slices around ``slice_number``.

    The window is half-open: it covers indices
    ``slice_number - slices_count`` up to, but not including,
    ``slice_number + slices_count``.

    Args:
        img: 3-D array indexed as ``img[:, slice, :]``.
        slice_number: Index on axis 1 the window is positioned around.
        slices_count: Number of slices taken on each side of the window start.

    Returns:
        A list of rotated, min-max normalised 2-D arrays.

    Raises:
        IndexError: If the window extends beyond axis 1 of ``img``.
    """
    return _coronal_window(img, slice_number - slices_count, slice_number + slices_count)


def separate_coronal_slices_around1(img, slice_number, slices_count):
    """Extract ``2 * slices_count + 1`` coronal slices centred on ``slice_number``.

    Unlike ``separate_coronal_slices_around`` the window is closed on both
    ends, so ``slices_count=0`` yields exactly the slice at ``slice_number``.

    Args:
        img: 3-D array indexed as ``img[:, slice, :]``.
        slice_number: Index on axis 1 at the centre of the window.
        slices_count: Number of slices taken on each side of the centre.

    Returns:
        A list of rotated, min-max normalised 2-D arrays.

    Raises:
        IndexError: If the window extends beyond axis 1 of ``img``.
    """
    return _coronal_window(img, slice_number - slices_count, slice_number + slices_count + 1)


def predict(file, age, sex):
    """Classify an uploaded volume and return a human-readable label.

    Args:
        file: Upload handed over by the Gradio ``File`` component; either an
            object exposing the path as ``.name`` or the path itself.
        age: Value of the age slider.
        sex: Index of the selected radio choice (the interface uses
            ``type='index'`` with the choices ``['Male', 'Female']``).

    Returns:
        One of the strings in ``predict_labels``, or an explanatory message
        when an input is missing or the image shape is not listed in
        ``input_shapes``.
    """
    # Gradio passes None for inputs the user left empty; report that instead
    # of failing with an AttributeError / tensor-construction error.
    if file is None:
        return "No file uploaded, please provide a .nii file"
    if age is None or sex is None:
        return "Please provide both age and sex"

    # Accept both a temp-file wrapper (path in `.name`) and a plain path.
    img = read_image(getattr(file, "name", file))
    if img.shape not in input_shapes:
        return "Incorrect file dimensions, try another"

    centre = input_shapes[img.shape]
    if count == 8:
        coronal_slices = separate_coronal_slices_around(img, centre, 4)
    else:
        coronal_slices = separate_coronal_slices_around1(img, centre, 0)

    # Stack into one ndarray first: creating a tensor directly from a list of
    # ndarrays is very slow and triggers a PyTorch warning.
    x = torch.tensor(np.stack(coronal_slices), dtype=torch.float).unsqueeze(0).to(device)
    y = torch.tensor((sex, age)).unsqueeze(0).to(device)

    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(x, y)
    return predict_labels[torch.argmax(output).item()]


# Gradio front end: the three input components are passed, in order, as the
# `file`, `age` and `sex` arguments of `predict`, whose return value is shown
# as text. The radio uses type="index", so `predict` receives the position of
# the chosen option rather than its label.
demo = gr.Interface(
    title="Alzheimer's Disease Detection",
    fn=predict,
    inputs=[
        gr.File(label="Upload .nii file"),
        gr.Slider(minimum=1, maximum=100, step=1, label="Age"),
        gr.Radio(choices=["Male", "Female"], type="index", label="Sex"),
    ],
    outputs="text",
)
demo.launch()