In [None]:
# Autoreload modules
%load_ext autoreload
%autoreload 2

In [None]:
# Make the project's local modules importable
import sys,os
sys.path.append(os.path.dirname(os.path.realpath('')) + '/Modules')

In [None]:
import numpy as np

import torch
import torch.nn as nn
from torchvision.transforms import Compose, Resize

from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification

import cv2
from matplotlib import pyplot as plt
from matplotlib import patches

from collections import Counter

import pandas as pd

from dataloader.dataset import ADNI3Channels
from dataloader.dataloader import ADNILoader
from atlas.atlas import ReadersAtlas3Channels, AAL3Channels

from utils.report import sklearn_classification_report

# Dataset and Dataloader Setup

In [None]:
# Presumably the ViT input resolution; no visible code reads this name
# (the ViT class below hard-codes Resize((384, 384))).
image_size_down = (384, 384)
# Note the (width, height) order, which is the dsize convention of cv2.resize.
# The ViT class below stores the same size as Resize((570, 950)), i.e. (height, width).
image_size_up = (950, 570)

In [None]:
# Build the three dataset splits with identical options; only the directory differs.
train_ds, valid_ds, test_ds = (
    ADNI3Channels(split_dir, transforms=None, rotate=True)
    for split_dir in ("../Data/Training/", "../Data/Validation/", "../Data/Test/")
)

In [None]:
# Sanity check: fetch one training sample, report its shape/label and the
# size of every split, then plot its three channels side by side.
idx = 0
image, label = train_ds[idx]

print("Image shape:", image.shape)
print("Label:", label.item())

print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for i, ax in enumerate(axes):
    ax.imshow(image[i])
    ax.axis("off")

In [None]:
# Class index <-> diagnosis name. The reverse map is derived from the forward
# map so the two can never drift apart.
id2label = dict(enumerate(("CN", "MCI", "AD")))
label2id = {name: index for index, name in id2label.items()}

print(id2label[label.item()])

In [None]:
# One sample per batch for every split.
train_batch_size = valid_batch_size = test_batch_size = 1

# Keyword arguments shared by every ADNILoader built below.
hparams = dict(
    train_ds=train_ds,
    valid_ds=valid_ds,
    test_ds=test_ds,
    train_batch_size=train_batch_size,
    valid_batch_size=valid_batch_size,
    test_batch_size=test_batch_size,
    num_workers=20,
    train_shuffle=False,
    valid_shuffle=False,
    test_shuffle=False,
    train_drop_last=False,
    valid_drop_last=False,
    test_drop_last=False,
)

# NOTE(review): ADNILoader is constructed three times with identical arguments;
# a single instance would probably do -- confirm the class keeps no per-call state.
train_dataloader = ADNILoader(**hparams).train_dataloader()
valid_dataloader = ADNILoader(**hparams).validation_dataloader()
test_dataloader = ADNILoader(**hparams).test_dataloader()

# Peek at one training batch to confirm the tensor shapes.
batch = next(iter(train_dataloader))
for batch_part in batch[:2]:
    print(batch_part.shape)

# Atlas

In [None]:
# Load the AAL atlas volume together with its region-label table.
atlas_reader = ReadersAtlas3Channels(
    aal_dir='../Data/AAL/Resized_AAL.nii',
    labels_dir='../Data/AAL/ROI_MNI_V4.txt',
    rotate=True,
)
atlas_data, atlas_labels = atlas_reader.get_data()

print(atlas_data.shape, '\n')
print(atlas_labels, '\n')

# Show the three atlas channels side by side.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for i, ax in enumerate(axes):
    ax.imshow(atlas_data[i])
    ax.axis("off")

# Model

In [None]:
class ViT(nn.Module):
    """ViT image classifier with attention-rollout based region attribution.

    Wraps the pretrained ``google/vit-base-patch32-384`` checkpoint with a
    ``num_labels``-way classification head and keeps a labelled atlas so that
    ``infer`` can name the atlas region receiving the highest attention.

    Args:
        atlas_data: Atlas tensor indexed as ``[channel, row, col]`` whose values
            are region ids (values ``> 0`` are treated as "inside the atlas").
        atlas_labels: Mapping from region name to region id.
        num_labels: Number of output classes.
    """

    # Pixels whose masked attention exceeds this fraction of the peak get a box.
    _HIGHLIGHT_FRACTION = 0.98
    # Side length, in pixels, of each highlight box drawn on the overlay.
    _BOX_SIZE = 20

    def __init__(self, atlas_data, atlas_labels, num_labels=3):
        super().__init__()
        self.resize_down = Resize((384, 384))
        self.resize_up = Resize((570, 950))
        self.id2label = {0: 'CN', 1: 'MCI', 2: 'AD'}
        # Kept as plain attributes (not buffers) so the checkpoint's state_dict
        # keys stay exactly as they were when the weights were saved.
        self.atlas_data = atlas_data
        self.atlas_labels = atlas_labels
        self.atlas_id2label = {value: key for key, value in atlas_labels.items()}
        self.vit = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch32-384',
            output_attentions=True,
            output_hidden_states=True,
            num_labels=num_labels,
            hidden_dropout_prob=0.1,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Rotate, downsize and classify a batch.

        Args:
            x: Image batch indexed as ``[batch, channel, row, col]``.

        Returns:
            ``(logits, attentions)`` as produced by the wrapped ViT model.
        """
        # Rotate the two spatial axes by -90 degrees before resizing.
        x = torch.rot90(x, -1, [2, 3])
        x = self.resize_down(x)
        outputs = self.vit(x)
        return outputs.logits, outputs.attentions

    @staticmethod
    def _attention_rollout(attentions):
        """Collapse per-layer attentions into one class-token attention grid.

        Code inspired by
        "https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb".

        Args:
            attentions: Sequence of per-layer tensors shaped
                ``(1, heads, tokens, tokens)`` (a single-sample batch).

        Returns:
            A CPU tensor of shape ``(grid, grid)`` with ``grid = int(sqrt(tokens))``,
            holding row 0 of the rolled-out attention with token 0 dropped.
        """
        att = torch.stack(attentions).squeeze(1)  # (layers, heads, tokens, tokens)
        att = att.mean(dim=1)                     # average over heads
        # Add the identity to account for residual connections, then re-normalise rows.
        att = att + torch.eye(att.size(-1), device=att.device, dtype=att.dtype)
        att = att / att.sum(dim=-1, keepdim=True)

        # Multiply the layers together, first layer innermost.
        joint = att[0]
        for layer_att in att[1:]:
            joint = layer_att @ joint

        grid_size = int(np.sqrt(att.size(-1)))
        return joint[0, 1:].reshape(grid_size, grid_size).detach().cpu()

    @staticmethod
    def _show_channels(volume, figsize):
        """Plot the first three channels of ``volume`` side by side, axes hidden."""
        fig, axes = plt.subplots(ncols=3, figsize=figsize, dpi=300)
        for i, ax in enumerate(axes):
            ax.imshow(volume[i])
            ax.axis('off')
        return fig, axes

    def infer(self, x, show_input=False, show_overlay=False, show_atlas=False):
        """Classify one image and name the atlas region with the highest attention.

        Unlike a plain forward pass this runs in eval mode under ``no_grad``
        (the previous training/eval mode is restored afterwards) and relies only
        on ``x`` and the atlas stored on the instance, not on notebook globals.

        Args:
            x: A single-sample batch indexed as ``[1, channel, row, col]`` whose
                spatial size matches the atlas.
            show_input: Plot the three input channels.
            show_overlay: Plot the heatmap over the input, boxing the pixels
                with the highest atlas-masked attention.
            show_atlas: Plot the atlas-masked heatmap.

        Returns:
            ``(pred, region)``: the predicted class name and the atlas region name.

        NOTE(review): ``forward`` rotates the input by -90 degrees, but the heatmap
        is resized onto the atlas grid without an inverse rotation -- confirm the
        orientations really line up.
        """
        model_device = next(self.parameters()).device
        was_training = self.training
        self.eval()  # make sure dropout is off while attributing
        try:
            with torch.no_grad():
                logits, attention = self(x.to(model_device))
        finally:
            self.train(was_training)
        pred = self.id2label[logits.argmax(1).item()]

        # Attention heatmap, normalised to a peak of 1 and upsampled in two steps.
        heatmap = self._attention_rollout(attention).numpy()
        side = x.shape[-1]
        heatmap = cv2.resize(heatmap / heatmap.max(), (side, side))[..., np.newaxis]
        # Duplicate heatmap to form a 3 channel image
        heatmap = np.concatenate([heatmap] * 3, axis=2)

        atlas = self.atlas_data
        atlas_rows, atlas_cols = atlas.shape[-2:]
        # cv2.resize expects dsize as (width, height).
        heatmap = cv2.resize(heatmap, (atlas_cols, atlas_rows))
        heatmap = torch.tensor(heatmap).permute(2, 0, 1)

        # Keep attention inside the atlas only and read the region id at its peak.
        atlas_mask = torch.where(atlas > 0, 1, 0)
        atlas_masked_heatmap = atlas_mask * heatmap
        peak = torch.where(atlas_masked_heatmap == atlas_masked_heatmap.max(), 1, 0)
        region_id = int((peak * atlas).max().item())
        region = self.atlas_id2label[region_id]

        image = x.squeeze().detach().cpu()

        if show_input:
            self._show_channels(image, figsize=(12, 2))

        if show_overlay:
            # Overlaying heatmap on image (only where the image is non-zero).
            image_mask = torch.where(image > 0, 1, 0)
            overlay = image_mask * heatmap * 2 + image
            fig, axes = self._show_channels(overlay, figsize=(12, 2.5))

            hot = atlas_masked_heatmap > atlas_masked_heatmap.max() * self._HIGHLIGHT_FRACTION
            half = self._BOX_SIZE // 2
            for ax, channel_hot in zip(axes, hot):
                for row, col in channel_hot.nonzero().tolist():
                    # Centre a box on each hot pixel; Rectangle takes (x, y) = (col, row).
                    ax.add_patch(patches.Rectangle((col - half, row - half),
                                                   self._BOX_SIZE, self._BOX_SIZE,
                                                   linewidth=0.1, edgecolor='r',
                                                   facecolor='none'))

            fig.suptitle(f"Prediction: {pred}\n Most Important Region: {region}")

        if show_atlas:
            self._show_channels(atlas_masked_heatmap, figsize=(12, 2.5))

        return pred, region

In [None]:
# Run everything on the CPU.
device = torch.device('cpu')
model = ViT(num_labels=3,
            atlas_data=atlas_data,
            atlas_labels=atlas_labels).to(device)
# Resizing and normalisation are switched off in the feature extractor.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch32-384',
                                                        do_resize=False,
                                                        do_normalize=False)

# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a GPU still loads on a CPU-only machine.
model.load_state_dict(torch.load("../ViT/Best models/ViT_Pretrained_acc.pt",
                                 map_location=device))

In [None]:
def predict(model, dataloader, device, extractor=None):
    """Run ``model`` over ``dataloader`` and collect true and predicted labels.

    Args:
        model: Callable returning ``(logits, attentions)`` for an image batch;
            must also provide ``eval()``.
        dataloader: Iterable of ``(x, y)`` batches of CPU tensors. Batches may
            differ in size (e.g. a shorter final batch).
        device: Device the inputs and labels are moved to before the forward pass.
        extractor: Callable mapping a list of per-sample arrays to a dict with a
            ``'pixel_values'`` entry. Defaults to the notebook-level
            ``feature_extractor``.

    Returns:
        ``(y_true, y_pred)`` as 1-D numpy arrays; both are empty when the
        dataloader yields no batches.
    """
    if extractor is None:
        # Fall back to the notebook-level extractor for existing callers.
        extractor = feature_extractor

    y_true, y_pred = [], []

    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            # One array per sample. Splitting by the actual batch length (rather
            # than the nominal batch size) also handles a shorter final batch.
            samples = list(np.array(x))
            pixel_values = extractor(samples)['pixel_values']
            x = torch.tensor(np.stack(pixel_values, axis=0))
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)

            y_pred.append(logits.argmax(1).cpu().numpy())
            y_true.append(y.cpu().numpy())

    if not y_true:
        # np.concatenate raises on an empty list.
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)

    return np.concatenate(y_true, axis=0), np.concatenate(y_pred, axis=0)

# Score the loaded model on the test split and pass the labels to the report helper.
y_true, y_pred = predict(model, test_dataloader, device)
sklearn_classification_report(y_true, y_pred)

# Inference


In [None]:
# Created in its own cell so that re-running the next cell advances to the next batch.
test_iter = iter(test_dataloader)

In [None]:
# Take the next test batch and run inference with every visualisation enabled.
x, y = next(test_iter)
plot_flags = dict(show_input=True, show_overlay=True, show_atlas=True)
pred, region = model.infer(x=x, **plot_flags)

# Ground truth next to the model's outputs.
print('Label:', id2label[y.item()])
print('Prediction:', pred)
print('Most Important Region:', region)

# Showing the Most Important Region on Atlas

In [None]:
# Multiply the atlas values of the selected region by 15 so it stands out;
# every other atlas value is left unchanged.
test = torch.where(atlas_data == atlas_labels[region], atlas_data * 15, atlas_data)

fig, axes = plt.subplots(ncols=3, figsize=(12, 2), dpi=300)
for channel_index, ax in enumerate(axes):
    ax.imshow(test[channel_index])
    ax.axis('off')

# Comparing the Most Important Regions with Readers' Suggestions

In [None]:
# Dataset built from the Readers directory, with the same options as the other splits.
readers_ds = ADNI3Channels("../Data/Readers/", transforms=None, rotate=True)

In [None]:
# Sanity check on the readers set: one sample's shape/label, the set size,
# and a side-by-side plot of its three channels.
idx = 0
image, label = readers_ds[idx]

print("Image shape:", image.shape)
print("Label:", label.item())

print("Number of readers samples:", len(readers_ds))

fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for i, ax in enumerate(axes):
    ax.imshow(image[i])
    ax.axis("off")

In [None]:
test_batch_size = 1

# Only the test slot is populated; it carries the readers dataset.
hparams = dict(
    train_ds=None,
    valid_ds=None,
    test_ds=readers_ds,
    train_batch_size=None,
    valid_batch_size=None,
    test_batch_size=test_batch_size,
    num_workers=20,
    train_shuffle=False,
    valid_shuffle=False,
    test_shuffle=False,
    train_drop_last=False,
    valid_drop_last=False,
    test_drop_last=False,
)

# NOTE(review): this rebinds `test_dataloader`, so from here on the name refers
# to the readers set rather than the test split -- confirm that is intended.
test_dataloader = ADNILoader(**hparams).test_dataloader()

# Peek at one batch to confirm the tensor shapes.
batch = next(iter(test_dataloader))
for batch_part in batch[:2]:
    print(batch_part.shape)

In [None]:
# Sample IDs that are common between ours and readers' test dataset
#1111, 2789, 1097, 2694, 2783, 886, 1057

# Position of the sample within readers_ds.
i = 6

print(readers_ds.files_dir[i])

# Fetch the sample and add a leading batch axis before inference.
x, y = readers_ds[i]
x = x.unsqueeze(0)

plot_flags = dict(show_input=True, show_overlay=True, show_atlas=True)
pred, region = model.infer(x=x, **plot_flags)

print('Label:', id2label[y.item()])
print('Prediction:', pred)
print('Most Important Region:', region)