In [None]:
# Autoreload modules
%load_ext autoreload
%autoreload 2
    
# Accessing modules
import sys,os
sys.path.append(os.path.realpath('../Modules'))

from dataloader.dataset import ADNI3Channels
from dataloader.dataloader import ADNILoader
from dataloader.transforms import Transforms

from model.model import ViT
from model.train import Trainer

from matplotlib import pyplot as plt
from utils.image import save_fig

In [None]:
import matplotlib.colors as mcolors
from ast import literal_eval

# Dataset and Dataloader Setup

In [None]:
# Lookup tables between integer class ids and diagnosis names.
id2label = dict(enumerate(("CN", "MCI", "AD")))
label2id = {name: idx for idx, name in id2label.items()}

transforms = Transforms(image_size=(384, 384), p=0.5)

# All three splits are built with the pipeline returned by `transforms.eval()`;
# a fresh `eval()` call is made for each split, in train -> validation -> test order.
train_ds, valid_ds, test_ds = (
    ADNI3Channels(split_dir, transforms=transforms.eval())
    for split_dir in ("../Data/Training/", "../Data/Validation/", "../Data/Test/")
)

In [None]:
# Pull the first training sample to sanity-check what the dataset returns.
image, label = train_ds[0]

print("Image shape:", image.shape)
print("Label:", id2label[label.item()], "\n")

# Size of every split.
print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

# Show the first three channels of the sample side by side, without axes.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for i, ax in enumerate(axes):
    ax.imshow(image[i, :, :])
    ax.axis("off")

# Value range of the sample after the transform pipeline.
print("Min pixel value =", image.min().item())
print("Max pixel value =", image.max().item())

In [None]:
# Dataset splits handed to every ADNILoader instance below.
kwargs = dict(train_ds=train_ds, valid_ds=valid_ds, test_ds=test_ds)

# One dataloader per split, each obtained from its own ADNILoader instance.
train_dataloader = ADNILoader(**kwargs).train_dataloader()
valid_dataloader = ADNILoader(**kwargs).validation_dataloader()
test_dataloader = ADNILoader(**kwargs).test_dataloader()

# Fetch a single training batch and report the shapes of its first two elements.
batch = next(iter(train_dataloader))
print(batch[0].shape, batch[1].shape, sep="\n")

# Atlas

In [None]:
from atlas.atlas import AAL3Channels

# Build the AAL atlas helper and unpack the two objects returned by `get_data()`.
atlas_data, atlas_labels = AAL3Channels(
    aal_dir='../Data/AAL/Resized_AAL.nii',
    labels_dir='../Data/AAL/ROI_MNI_V4.txt',
    rotate=True,
).get_data()

# Quick summary: shape, number of labels, and value range of the atlas data.
print(atlas_data.shape, '\n')
print(len(atlas_labels), '\n')
print(atlas_data.min(), atlas_data.max())

# Show the first three slices along the leading axis, without axes.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for i, ax in enumerate(axes):
    ax.imshow(atlas_data[i, :, :])
    ax.axis("off")

# Loading Model

In [None]:
# Instantiate the ViT wrapper on the first CUDA device.
model = ViT(pretrained=True, model_name="google/vit-base-patch32-384", device="cuda:0")

# Restore saved weights from the "Best models" directory.
# NOTE(review): "acc" presumably selects the checkpoint with the best accuracy —
# confirm against `ViT.load_best_state_file`.
model.load_best_state_file("acc", "../ViT/Best models/", "ViT_Pretrained")

# Trainer configuration; the trainer is only used for evaluation further down.
kwargs = dict(
    epochs=100,
    model=model,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    test_dataloader=test_dataloader,
)

trainer = Trainer(**kwargs)

In [None]:
# Evaluate the loaded model on the test dataloader only.
# The two calls below are intentionally left commented out; uncomment them to
# run the same evaluation on the training / validation dataloaders as well.
# trainer.test(trainer.train_dataloader)
# trainer.test(trainer.valid_dataloader)
trainer.test(trainer.test_dataloader)

# Inference

In [None]:
# Worked example on a single test sample (index 34).
x, y = test_ds[34]

# Every visualisation flag is switched on for this single-sample run.
display_flags = dict(
    show_overlaid_attention_map=True,
    show_patches=True,
    show_attention_map=True,
    show_input=True,
)
pred, region = model.infer(
    x=x,
    atlas_data=atlas_data,
    atlas_labels=atlas_labels,
    **display_flags,
)

# Compare the ground-truth class name with the prediction and the reported region.
id2label = {0: 'CN', 1: 'MCI', 2: 'AD'}
print('Label:', id2label[y.item()])
print('Prediction:', pred)
print('Most Important Region:', region)

# Regions' Importance

In [None]:
from torch.utils.data import ConcatDataset, DataLoader

In [None]:
# Pool the three splits into one dataset, in train -> validation -> test order.
all_splits = [train_ds, valid_ds, test_ds]
all_ds = ConcatDataset(all_splits)
print(len(all_ds))

In [None]:
# Sequential (unshuffled) loader over the pooled dataset; the last, possibly
# smaller, batch is kept.
all_loader_options = dict(batch_size=5, shuffle=False, num_workers=20, drop_last=False)
all_dataloader = DataLoader(all_ds, **all_loader_options)

In [None]:
# Run the trainer's test routine over the pooled train + validation + test loader.
trainer.test(all_dataloader)

In [None]:
import numpy as np

def extract_regions_importance(dataset, label=None, only_non_zeros=False):
    """Count, per atlas region, how often it is reported for a correct prediction.

    Every ``(x, y)`` sample of ``dataset`` is passed to the notebook-level
    ``model.infer`` together with ``atlas_data`` and ``atlas_labels`` (all
    visualisation flags off). A sample adds one count to the region returned
    by ``infer`` only when the predicted class name equals the true class name
    and, if ``label`` is given, also equals ``label``. The counts are then
    min-max normalised.

    Args:
        dataset: Iterable of ``(x, y)`` pairs; ``y.item()`` must be 0, 1 or 2.
        label: Optional class name ("CN", "MCI" or "AD"). When falsy, correct
            predictions of every class are counted.
        only_non_zeros: When True, the result is sorted by descending
            importance and regions whose normalised value is 0 are dropped.

    Returns:
        dict mapping region keys (the keys of ``atlas_labels``) to their
        normalised importance. When every region has the same count, min-max
        scaling is undefined; in that case every region gets 0.0 if nothing
        was counted and 1.0 otherwise. An empty ``atlas_labels`` yields an
        empty dict.
    """
    id2label = {0: 'CN', 1: 'MCI', 2: 'AD'}

    # Region counter
    region_cnt = {key: 0 for key in atlas_labels.keys()}

    for x, y in dataset:
        pred, region = model.infer(x=x,
                                   atlas_data=atlas_data,
                                   atlas_labels=atlas_labels,
                                   show_overlaid_attention_map=False,
                                   show_patches=False,
                                   show_attention_map=False,
                                   show_input=False)

        # Only correct predictions count; with a label filter the prediction
        # must additionally be that label.
        is_correct = id2label[y.item()] == pred
        if is_correct and (not label or label == pred):
            region_cnt[region] += 1

    # Nothing to normalise (max()/min() would raise on an empty mapping).
    if not region_cnt:
        return region_cnt

    # Normalization
    cnt_max = max(region_cnt.values())
    cnt_min = min(region_cnt.values())
    cnt_range = cnt_max - cnt_min

    if cnt_range == 0:
        # Degenerate case (e.g. no correct prediction for `label`): avoid 0/0.
        flat_value = 1.0 if cnt_max > 0 else 0.0
        region_cnt = {name: flat_value for name in region_cnt}
    else:
        region_cnt = {
            name: (cnt - cnt_min) / cnt_range
            for name, cnt in region_cnt.items()
        }

    # Checking for only_non_zeros
    if only_non_zeros:
        # Sort by descending importance, then drop the zero entries.
        ranked = sorted(region_cnt.items(), key=lambda item: item[1], reverse=True)
        region_cnt = {name: value for name, value in ranked if value != 0}

    return region_cnt
    

def plot_regions_importance(region_cnts, key, n_colors=8):
    """Draw a bar chart of region importances and hand it to ``save_fig``.

    Args:
        region_cnts: Mapping of region name -> importance value; bars are
            drawn in the mapping's iteration order.
        key: Passed as the first argument to ``save_fig`` together with the
            figure.
        n_colors: Number of evenly spaced samples taken from the 'viridis'
            colormap for the bar colours. Defaults to 8, the value that was
            previously hard-coded; pass ``len(region_cnts)`` to give every
            bar its own colour.
    """
    my_cmap = plt.get_cmap('viridis')
    colors = my_cmap(np.linspace(0, 1, n_colors))

    fig, ax = plt.subplots(
        figsize=(4, 2),
        dpi=300,
        layout="tight"
    )

    # Pass explicit lists so the call does not rely on how Matplotlib treats
    # dict views.
    ax.bar(
        x=list(region_cnts.keys()),
        height=list(region_cnts.values()),
        width=0.5,
        color=colors,
    )

    # Small serif tick labels; region names are rotated to stay readable.
    ax.tick_params(axis='x', labelfontfamily="serif", labelrotation=90, labelsize=4)
    ax.tick_params(axis='y', labelfontfamily="serif", labelsize=4)

    ax.set_xlabel("Regions", fontname="serif", fontsize="xx-small",)
    ax.set_ylabel("Importance", fontname="serif", fontsize="xx-small",)

    save_fig(key, fig)

In [None]:
# Region importance for each dataset
# Region importance for each dataset: all three splits are processed first,
# then one bar chart is drawn per split, in train -> validation -> test order.
region_cnts = {
    split_name: extract_regions_importance(split_ds, only_non_zeros=True)
    for split_name, split_ds in (
        ("Training", train_ds),
        ("Validation", valid_ds),
        ("Test", test_ds),
    )
}

for split_name, split_cnt in region_cnts.items():
    plot_regions_importance(split_cnt, split_name)

In [None]:
# Overall region importance and for each label
# Overall region importance and for each label, computed on the pooled dataset.
# A label filter of None means "all classes" (the function's default).
region_cnts = {
    plot_key: extract_regions_importance(all_ds, label_filter, only_non_zeros=True)
    for plot_key, label_filter in (
        ("CN", "CN"),
        ("MCI", "MCI"),
        ("AD", "AD"),
        ("Overall", None),
    )
}

for plot_key, label_cnt in region_cnts.items():
    plot_regions_importance(label_cnt, plot_key)

# Regions' Importance Heatmap

In [None]:
def extract_regions_heatmap(dataset, label=None):
    """Average the attention maps of correctly predicted samples and rescale to [0, 1].

    Every ``(x, y)`` sample of ``dataset`` is passed to the notebook-level
    ``model.infer`` (visualisation flags off, ``return_att_map=True``). The
    returned attention map is included in the average only when the predicted
    class name equals the true class name and, if ``label`` is given, also
    equals ``label``.

    Args:
        dataset: Iterable of ``(x, y)`` pairs; ``y.item()`` must be 0, 1 or 2.
        label: Optional class name ("CN", "MCI" or "AD"). When falsy, correct
            predictions of every class are averaged.

    Returns:
        The mean attention map, min-max normalised. If the mean map is
        constant (max == min) an all-zero map is returned instead of dividing
        by zero.

    Raises:
        ValueError: If no sample satisfies the selection criteria, so there
            is nothing to average.
    """
    id2label = {0: 'CN', 1: 'MCI', 2: 'AD'}

    # Running sum instead of a list: only one accumulated map is kept in
    # memory regardless of the dataset size.
    att_sum = None
    n_maps = 0

    for x, y in dataset:
        pred, _region, att_map = model.infer(x=x,
                                             atlas_data=atlas_data,
                                             atlas_labels=atlas_labels,
                                             show_overlaid_attention_map=False,
                                             show_patches=False,
                                             show_attention_map=False,
                                             show_input=False,
                                             return_att_map=True)

        is_correct = id2label[y.item()] == pred
        if is_correct and (not label or label == pred):
            # Out-of-place addition so the maps returned by `infer` are never mutated.
            att_sum = att_map if att_sum is None else att_sum + att_map
            n_maps += 1

    if n_maps == 0:
        raise ValueError(
            f"No correctly predicted samples found (label={label!r}); "
            "cannot compute a mean attention map."
        )

    mean_att_map = att_sum / n_maps

    # Normalization
    map_min = mean_att_map.min()
    shifted = mean_att_map - map_min
    value_range = mean_att_map.max() - map_min
    if value_range == 0:
        # Constant map: `shifted` is already all zeros; skip the 0/0 division.
        return shifted

    return shifted / value_range

def plot_regions_heatmap(mean_att_map, key, channel=2):
    """Render one slice of a mean attention map as a heatmap and save it.

    Args:
        mean_att_map: Indexable container of 2-D maps; ``mean_att_map[channel]``
            is the image that gets drawn. Values are displayed on a fixed
            0..1 colour scale.
        key: Figure name prefix; the figure is passed to ``save_fig`` under
            the name ``f"{key}_heatmap"``.
        channel: Index of the slice to draw. Defaults to 2, the index that was
            previously hard-coded.
    """
    fig, ax = plt.subplots(
        figsize=(3, 2),
        dpi=300,
        layout="tight"
    )

    # Fixed colour limits keep heatmaps from different calls comparable.
    im = ax.imshow(
        mean_att_map[channel],
        vmin=0,
        vmax=1
    )

    cbar = fig.colorbar(im, ax=ax, shrink=0.76)
    cbar.ax.tick_params(labelsize="small")
    ax.axis("off")

    save_fig(f"{key}_heatmap", fig)

In [None]:
# Region heatmaps for each dataset
# Region heatmaps for each dataset: all three mean maps are computed first,
# then one heatmap is drawn per split, in train -> validation -> test order.
mean_att_maps = {
    split_name: extract_regions_heatmap(split_ds)
    for split_name, split_ds in (
        ("Training", train_ds),
        ("Validation", valid_ds),
        ("Test", test_ds),
    )
}

for split_name, split_map in mean_att_maps.items():
    plot_regions_heatmap(split_map, split_name)

In [None]:
# Mean attention heatmaps on the pooled dataset: one per class plus an overall
# map. A label filter of None means "all classes" (the function's default).
mean_att_maps = {
    plot_key: extract_regions_heatmap(all_ds, label_filter)
    for plot_key, label_filter in (
        ("CN", "CN"),
        ("MCI", "MCI"),
        ("AD", "AD"),
        ("Overall", None),
    )
}

for plot_key, label_map in mean_att_maps.items():
    plot_regions_heatmap(label_map, plot_key)