In [None]:
# Autoreload: re-import edited project modules (e.g. those under ../Modules) before each
# cell executes, so edits to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
    
# Make the project's local packages under ../Modules (dataloader, model, atlas, utils)
# importable. The path is resolved relative to the notebook's working directory.
# Guarded so re-running this cell does not append duplicate entries to sys.path.
import os
import sys

MODULES_DIR = os.path.realpath('../Modules')
if MODULES_DIR not in sys.path:
    sys.path.append(MODULES_DIR)

from dataloader.dataset import ADNI3Channels
from dataloader.dataloader import ADNILoader
from dataloader.transforms import Transforms

from model.model import ViT
from model.train import Trainer

from matplotlib import pyplot as plt
from utils.image import save_fig

In [None]:
import numpy as np

import torch
import torch.nn as nn
import pandas as pd

# Dataset and Dataloader Setup

In [None]:
# Class-index <-> diagnosis-name lookups; label2id is the exact inverse of id2label.
id2label = {0: "CN", 1: "MCI", 2: "AD"}
label2id = {name: idx for idx, name in id2label.items()}

DATA_ROOT = "../Data/"

# 384x384 inputs, matching the "-384" ViT checkpoint loaded later in this notebook.
transforms = Transforms(image_size=(384, 384), p=0.5)

# Every split, including training, is built with transforms.eval(). This is presumably
# deliberate: this notebook only evaluates a saved checkpoint, so no training-time
# augmentation is wanted (TODO confirm what `p` controls inside Transforms).
train_ds = ADNI3Channels(DATA_ROOT + "Training/", transforms=transforms.eval())
valid_ds = ADNI3Channels(DATA_ROOT + "Validation/", transforms=transforms.eval())
test_ds = ADNI3Channels(DATA_ROOT + "Test/", transforms=transforms.eval())

In [None]:
# Inspect the first training sample and report the size of each split.
image, label = train_ds[0]

# Per-sample information.
print("Image shape:", image.shape)
print("Label:", id2label[label.item()], "\n")

# Split sizes.
print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

# Intensity range of this sample.
print("Min pixel value =", image.min().item())
print("Max pixel value =", image.max().item())

# Show the first three slices along dim 0 (the channels) side by side.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for ax, channel in zip(axes, image):
    ax.imshow(channel)
    ax.axis("off")

In [None]:
# Wrap the three datasets in a single ADNILoader and take each split's loader from it.
# One instance serves all splits; building a separate, identical loader per split is redundant.
loader_kwargs = {
    'train_ds': train_ds,
    'valid_ds': valid_ds,
    'test_ds': test_ds,
}
adni_loader = ADNILoader(**loader_kwargs)

train_dataloader = adni_loader.train_dataloader()
valid_dataloader = adni_loader.validation_dataloader()
test_dataloader = adni_loader.test_dataloader()

# Sanity check on the first training batch: presumably (images, labels), given that the
# datasets yield (image, label) pairs. TODO confirm against ADNILoader.
batch = next(iter(train_dataloader))
print(batch[0].shape)
print(batch[1].shape)

# Atlas

In [None]:
from atlas.atlas import AAL3Channels

# Load the resized AAL atlas (three slices along dim 0, mirroring the image inputs)
# together with its region-label lookup.
atlas_loader = AAL3Channels(
    aal_dir='../Data/AAL/Resized_AAL.nii',
    labels_dir='../Data/AAL/ROI_MNI_V4.txt',
    rotate=True,
)
atlas_data, atlas_labels = atlas_loader.get_data()

print(atlas_data.shape, '\n')
print(len(atlas_labels), '\n')

print(atlas_data.min(), atlas_data.max())

# Display the three atlas slices side by side.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for ax, atlas_slice in zip(axes, atlas_data):
    ax.imshow(atlas_slice)
    ax.axis("off")

# Loading Model

In [None]:
# Use the first GPU when one is present; otherwise fall back to CPU instead of crashing.
# NOTE(review): assumes ViT and load_best_state_file accept a CPU device string
# (e.g. load the state with a suitable map_location). TODO confirm.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

model = ViT(
    pretrained=True,
    model_name="google/vit-base-patch32-384",
    device=DEVICE,
)

# Restore a saved checkpoint for the "ViT_Pretrained" run. The "acc" argument presumably
# selects the best checkpoint by accuracy. TODO confirm.
model.load_best_state_file("acc", "../ViT/Best models/", "ViT_Pretrained")

# `epochs` is required by Trainer but is not exercised here: this notebook only calls test().
trainer_kwargs = {
    "epochs": 100,
    "model": model,
    "train_dataloader": train_dataloader,
    "valid_dataloader": valid_dataloader,
    "test_dataloader": test_dataloader,
}

trainer = Trainer(**trainer_kwargs)

In [None]:
# Evaluate the restored checkpoint on the held-out test split.
# To score the other splits, pass trainer.train_dataloader or trainer.valid_dataloader instead.
trainer.test(trainer.test_dataloader)

# Inference

In [None]:
# Hand-picked test-set indices, one per diagnosis. Presumably each index holds a sample of
# that class; verify against test_ds.
EXAMPLE_INDEX = {"CN": 14, "MCI": 58, "AD": 36}

x, y = test_ds[EXAMPLE_INDEX["AD"]]

# Run inference with all four display flags enabled.
display_flags = dict(
    show_overlaid_attention_map=True,
    show_patches=True,
    show_attention_map=True,
    show_input=True,
)
pred, region = model.infer(
    x=x,
    atlas_data=atlas_data,
    atlas_labels=atlas_labels,
    **display_flags,
)

print('Label:', id2label[y.item()])
print('Prediction:', pred)
print('Most Important Region:', region)

# Showing the Most Important Region on Atlas

In [None]:
# Scale up the atlas values of the most important region so it stands out in the plot.
# The result keeps the name `test` so any later cells that read it still work.
HIGHLIGHT_FACTOR = 15
region_mask = atlas_data == atlas_labels[region]
test = torch.where(region_mask, atlas_data * HIGHLIGHT_FACTOR, atlas_data)

# Show the three highlighted atlas slices side by side.
fig, axes = plt.subplots(ncols=3, figsize=(12, 2), dpi=300)
for ax, highlighted_slice in zip(axes, test):
    ax.imshow(highlighted_slice)
    ax.axis('off')