In [None]:
# Autoreload modules
%load_ext autoreload
%autoreload 2
    
# Accessing moduels
import sys,os
sys.path.append(os.path.realpath('../Modules'))

from dataloader.dataset import ADNI3Channels
from dataloader.dataloader import ADNILoader
from dataloader.transforms import Transforms

from model.model import ViT
from model.train import Trainer

from matplotlib import pyplot as plt

In [None]:
import numpy as np

import torch
# import torch.nn as nn
# from torchvision.transforms import Compose, Resize
from atlas.atlas import AAL3Channels

# Dataset and Dataloader Setup

In [None]:
# Class index <-> diagnosis name. label2id is derived from id2label so the two cannot drift apart.
id2label = {0: "CN", 1: "MCI", 2: "AD"}
label2id = {name: class_idx for class_idx, name in id2label.items()}

# One Transforms object shared by all splits; each split receives its own transforms.eval() result.
transforms = Transforms(image_size=(384, 384), p=0.5)

# NOTE(review): the training split is also built with eval() transforms. That is fine for the
# evaluation-only use in this notebook, but revisit it if this cell is reused for training.
train_ds, valid_ds, test_ds = (
    ADNI3Channels(split_dir, transforms=transforms.eval())
    for split_dir in ("../Data/Training/", "../Data/Validation/", "../Data/Test/")
)

In [None]:
# Peek at the first training sample and report the size of every split.
image, label = train_ds[0]

print("Image shape:", image.shape)
print("Label:", id2label[label.item()], "\n")

print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

# One panel per input channel (three panels, indexed along the first dimension).
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for channel_idx, ax in enumerate(axes):
    ax.imshow(image[channel_idx, :, :])
    ax.set_axis_off()

# Intensity range of this sample after the dataset's transforms.
pixel_min, pixel_max = image.min().item(), image.max().item()
print("Min pixel value =", pixel_min)
print("Max pixel value =", pixel_max)

In [None]:
# Build the train/validation/test DataLoaders from a single ADNILoader instead of
# constructing three identical ADNILoader objects.
loader_kwargs = {
    "train_ds": train_ds,
    "valid_ds": valid_ds,
    "test_ds": test_ds,
}
adni_loader = ADNILoader(**loader_kwargs)

train_dataloader = adni_loader.train_dataloader()
valid_dataloader = adni_loader.validation_dataloader()
test_dataloader = adni_loader.test_dataloader()

# Sanity-check one training batch. Element 0 is presumably the image batch and element 1
# the labels; confirm against ADNILoader.
batch = next(iter(train_dataloader))
print(batch[0].shape)
print(batch[1].shape)

# Loading Model

In [None]:
# Use the first GPU when available; otherwise fall back to CPU instead of crashing on a hard-coded "cuda:0".
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

model = ViT(
    pretrained=True,
    model_name="google/vit-base-patch32-384",
    device=DEVICE,
)

# Restore the saved "ViT_Pretrained" state from ../ViT/Best models/. The "acc" argument
# presumably selects the best-accuracy checkpoint; confirm in ViT.load_best_state_file.
model.load_best_state_file("acc", "../ViT/Best models/", "ViT_Pretrained")

# The Trainer is only used for evaluation below. "epochs" is passed through, but no
# training call appears in this notebook.
trainer_kwargs = {
    "epochs": 100,
    "model": model,
    "train_dataloader": train_dataloader,
    "valid_dataloader": valid_dataloader,
    "test_dataloader": test_dataloader,
}

trainer = Trainer(**trainer_kwargs)

In [None]:
# Evaluate the restored checkpoint on the held-out test split.
# To score another split, pass trainer.train_dataloader or trainer.valid_dataloader instead.
split_loader = trainer.test_dataloader
trainer.test(split_loader)

# Getting Attention Map

In [None]:
from model.attention import get_attention_map

idx = 10  # index into test_ds of the scan to visualise
image, label = test_ds[idx]

# The model takes a batch, so add a leading batch dimension of 1 (alternatively the
# dataloader could be configured with batch_size=1). no_grad() skips building an
# autograd graph, because nothing here backpropagates.
# NOTE(review): confirm the model is in eval mode (dropout off); it is not clear from
# here whether trainer.test() leaves it that way.
with torch.no_grad():
    logits, attention, _ = model(image.unsqueeze(0).to(model.device))

true_name = id2label[label.item()]
pred_name = id2label[torch.argmax(logits).item()]
print("Label:", true_name)
print("Prediction:", pred_name)

img, att_map = get_attention_map(image, attention, model.device, rotate=True)

# Mask: zero the attention wherever the input scan is zero, then keep channel 2.
# The input is converted to HWC and rotated with np.rot90, presumably to match
# rotate=True above (TODO confirm).
image_hwc = np.rot90(image.permute(1, 2, 0).numpy())
img = (img * np.where(image_hwc > 0, 1, 0))[:, :, 2]

# Min-max normalize to [0, 1]. A constant map (e.g. fully masked) becomes all zeros
# instead of dividing by zero.
img_min = img.min()
img_range = img.max() - img_min
img = (img - img_min) / img_range if img_range > 0 else img - img_min

# Draw the map once with an explicit colormap so the colorbar matches what is displayed.
fig, ax = plt.subplots(dpi=300)
im = ax.imshow(img, cmap="plasma")
ax.set_title(f"Label: {true_name} | Prediction: {pred_name}")
ax.axis("off")
cbar = fig.colorbar(im);