In [None]:
# Autoreload modules: `%autoreload 2` re-imports modules before each cell runs,
# so edits to the imported project modules are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Make the project's local packages importable. They live in a `Modules`
# directory inside the parent of the current working directory (presumably
# the dataloader/atlas/utils packages imported below).
# The membership check keeps sys.path free of duplicates when this cell is re-run.
import os
import sys

MODULES_DIR = os.path.join(os.path.dirname(os.path.realpath('')), 'Modules')
if MODULES_DIR not in sys.path:
    sys.path.append(MODULES_DIR)

In [None]:
import numpy as np

import torch
import torch.nn as nn
from torchvision.transforms import Compose, Resize

from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification

import cv2
from matplotlib import pyplot as plt

from collections import Counter

import pandas as pd

from dataloader.dataset import ADNI3Channels
from dataloader.dataloader import ADNILoader
from atlas.atlas import ReadersAtlas3Channels, AAL3Channels

from utils.report import sklearn_classification_report

# Dataset and Dataloader Setup

In [None]:
# Every split gets the same preprocessing: a plain resize to 384x384 (the input
# resolution of the 'vit-base-patch32-384' checkpoint used below). No augmentation.
image_size_down = (384, 384)
image_size_up = (950, 570)  # NOTE(review): not referenced elsewhere in the code shown

train_transforms, valid_transforms, test_transforms = (
    Compose([Resize(size=image_size_down)]) for _split in ("train", "valid", "test")
)

In [None]:
# One dataset per split, each paired with its own transform pipeline.
train_ds, valid_ds, test_ds = (
    ADNI3Channels(split_dir, transforms=split_transforms)
    for split_dir, split_transforms in (
        ("../Data/Training/", train_transforms),
        ("../Data/Validation/", valid_transforms),
        ("../Data/Test/", test_transforms),
    )
)

In [None]:
# Sanity check: shape and label of one training sample, the size of every
# split, and the sample's three channels plotted side by side.
idx = 0
image, label = train_ds[idx]

print("Image shape:", image.shape)
print("Label:", label.item())

print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for channel_idx, ax in enumerate(axes):
    ax.imshow(image[channel_idx, :, :])
    ax.axis("off")
    # print(image[i, :, :].min(), image[i, :, :].max())

In [None]:
# Class index <-> group name (presumably the usual ADNI groups: cognitively
# normal, mild cognitive impairment, Alzheimer's disease).
id2label = {0: "CN", 1: "MCI", 2: "AD"}
# Derive the inverse mapping so the two dictionaries can never drift apart.
label2id = {name: class_id for class_id, name in id2label.items()}
assert len(label2id) == len(id2label), "class names must be unique"

print(id2label[label.item()])

In [None]:
# Evaluation-style loading: batch size 1, no shuffling and no dropped samples.
train_batch_size = 1
valid_batch_size = 1
test_batch_size = 1

hparams = {'train_ds': train_ds,
           'valid_ds': valid_ds,
           'test_ds': test_ds,
           'train_batch_size': train_batch_size,
           'valid_batch_size': valid_batch_size,
           'test_batch_size': test_batch_size,
           'num_workers': 20,
           'train_shuffle': False,
           'valid_shuffle': False,
           'test_shuffle': False,
           'train_drop_last': False,
           'valid_drop_last': False,
           'test_drop_last': False,
          }

# Build the loader factory once and derive all three splits from it,
# instead of constructing an identical ADNILoader per split.
adni_loader = ADNILoader(**hparams)
train_dataloader = adni_loader.train_dataloader()
valid_dataloader = adni_loader.validation_dataloader()
test_dataloader = adni_loader.test_dataloader()

# Peek at one batch to confirm image and label shapes.
batch = next(iter(train_dataloader))
print(batch[0].shape)
print(batch[1].shape)

# Model

In [None]:
class ViT(nn.Module):
    """Image classifier wrapping a pretrained Hugging Face ViT checkpoint.

    The classification head is built for ``num_labels`` classes;
    ``ignore_mismatched_sizes=True`` lets it differ in shape from the
    checkpoint's head. The backbone is configured to return attention maps
    so they can be visualised (hidden states are also requested but are not
    returned by ``forward``).

    Args:
        num_labels: number of output classes.
        pretrained_name: Hugging Face model id or local path passed to
            ``from_pretrained``. Defaults to the checkpoint this notebook uses.
    """

    def __init__(self, num_labels=3, pretrained_name='google/vit-base-patch32-384'):
        super().__init__()
        # Attention dropout is intentionally not overridden here.
        self.vit = ViTForImageClassification.from_pretrained(
            pretrained_name,
            output_attentions=True,
            output_hidden_states=True,
            num_labels=num_labels,
            hidden_dropout_prob=0.1,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: batch of pixel values; presumably (batch, 3, 384, 384) — TODO confirm.

        Returns:
            Tuple ``(logits, attentions)``: the classifier logits and the
            per-layer attention tensors produced by the backbone.
        """
        outputs = self.vit(x)
        return outputs.logits, outputs.attentions

In [None]:
# Build the model on CPU, in inference mode, and restore the fine-tuned weights.
device = torch.device('cpu')
model = ViT(num_labels=len(id2label)).to(device).eval()

# Resizing and normalisation are turned off in the extractor; the dataset
# transforms already resize every image to 384x384.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch32-384',
                                                        do_resize=False,
                                                        do_normalize=False)

# map_location remaps tensors saved on another device (e.g. a GPU training run)
# onto `device`; without it, loading a CUDA checkpoint on a CPU-only machine fails.
model.load_state_dict(torch.load("../ViT/Best models/ViT_Pretrained_acc.pt",
                                 map_location=device))

In [None]:
def predict(model, dataloader, device, extractor=None):
    """Classify every batch of ``dataloader`` and collect labels and predictions.

    Each batch is split into individual images, passed through a Hugging Face
    style feature extractor to obtain ``pixel_values``, and classified by
    taking the arg-max over the model's logits.

    Args:
        model: module whose forward returns ``(logits, attentions)``.
        dataloader: iterable of ``(images, labels)`` tensor batches.
        device: device the model lives on; inputs are moved there.
        extractor: callable taking a list of per-image arrays and returning a
            mapping with a ``'pixel_values'`` entry. Defaults to the
            notebook-level ``feature_extractor`` (backward compatible).

    Returns:
        Tuple ``(y_true, y_pred)`` of 1-D numpy arrays with the ground-truth and
        predicted class indices, in dataloader order. Both are empty when the
        dataloader yields no batches.
    """
    if extractor is None:
        extractor = feature_extractor  # notebook-global created with the model

    y_true = []
    y_pred = []

    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            # One array per image. Unlike np.split(..., batch_size), this also
            # handles a final batch smaller than batch_size (drop_last=False).
            images = list(x.cpu().numpy())
            pixel_values = np.stack(extractor(images)['pixel_values'], axis=0)
            x = torch.tensor(pixel_values).to(device)
            y = y.to(device)

            logits, _ = model(x)
            preds = logits.argmax(1)

            y_pred.append(preds.cpu().numpy())
            y_true.append(y.cpu().numpy())

    if not y_true:
        # np.concatenate raises on an empty list; return empty label arrays instead.
        empty = np.array([], dtype=np.int64)
        return empty, empty.copy()

    return np.concatenate(y_true, axis=0), np.concatenate(y_pred, axis=0)

# Evaluate on the held-out test split and report classification metrics
# via the project's report helper.
y_true, y_pred = predict(model=model, dataloader=test_dataloader, device=device)
sklearn_classification_report(y_true, y_pred)

# Attention Map

In [None]:
# code inspired by "https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb"
def get_attention(image, attention, device, rotate=False, heatmap_weight=3.5):
    """Compute an attention-rollout heatmap for one image and overlay it on the image.

    The per-layer attention is averaged over heads, combined with the identity
    (residual path), and row-normalised. The layers are then multiplied
    together. The CLS-token row of the final product, restricted to the patch
    tokens, forms the heatmap.

    Args:
        image: (C, H, W) tensor of the input image.
        attention: sequence of per-layer attention tensors for a single image,
            presumably each (1, heads, tokens, tokens) — TODO confirm.
        device: device on which the rollout is computed.
        rotate: if True, rotate both outputs by 90 degrees with ``np.rot90``.
        heatmap_weight: scale applied to the heatmap before it is added to the image.

    Returns:
        Tuple ``(result, heatmap)`` of (H, W, C) numpy arrays, where
        ``result = heatmap * heatmap_weight + image``.
    """
    channels, height, width = image.shape

    with torch.no_grad():
        # (layers, 1, heads, T, T) -> (layers, heads, T, T) -> average over heads.
        att_mat = torch.stack(attention).squeeze(1).to(device)
        att_mat = att_mat.mean(dim=1)

        # Add the identity for the residual connection, then re-normalise rows.
        residual_att = torch.eye(att_mat.size(-1), device=device)
        aug_att_mat = att_mat + residual_att
        aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

        # Roll the attention out through the layers.
        joint_attentions = torch.zeros_like(aug_att_mat)
        joint_attentions[0] = aug_att_mat[0]
        for n in range(1, aug_att_mat.size(0)):
            joint_attentions[n] = aug_att_mat[n] @ joint_attentions[n - 1]

        # CLS-token row of the last layer, excluding the CLS column itself.
        v = joint_attentions[-1]
        grid_size = int(round(np.sqrt(aug_att_mat.size(-1) - 1)))
        heatmap = v[0, 1:].reshape(grid_size, grid_size).cpu().numpy()

    # Scale to [0, 1] and upsample to the image size; cv2.resize takes (width, height).
    heatmap = cv2.resize(heatmap / heatmap.max(), (width, height))[..., np.newaxis]
    # Replicate across channels so it can be added to the (H, W, C) image.
    heatmap = np.repeat(heatmap, channels, axis=2)

    image_hwc = image.detach().permute(1, 2, 0).cpu().numpy()
    result = heatmap * heatmap_weight + image_hwc

    if rotate:
        heatmap = np.rot90(heatmap)
        result = np.rot90(result)

    return result, heatmap

In [None]:
# Attention rollout for one test sample, shown only where the input image is positive.
idx = 2
image, label = test_ds[idx]

model.eval()
# NOTE(review): predict() passes images through feature_extractor first, while
# this cell feeds the raw tensor. Confirm the two preprocessings agree.
with torch.no_grad():  # inference only; no autograd graph is needed
    logits, attention = model(image.unsqueeze(0).to(device))

true_name = id2label[label.item()]
pred_name = id2label[torch.argmax(logits).item()]
print("Label:", true_name)
print("Prediction:", pred_name)
img, att_map = get_attention(image, attention, device, rotate=True)

# Mask: keep the overlay only where the identically rotated input is positive.
image_hwc = np.rot90(image.permute(1, 2, 0).numpy())
img = img * np.where(image_hwc > 0, 1, 0)
img = img[:, :, 2]  # a single channel for the colormap

# Min-max normalise to [0, 1]; guard against a constant map (division by zero).
img_range = img.max() - img.min()
img = (img - img.min()) / img_range if img_range > 0 else np.zeros_like(img)

fig, ax = plt.subplots(dpi=300)
im = ax.imshow(img, cmap='plasma')
ax.set_title(f"Label: {true_name} | Prediction: {pred_name}")
ax.axis("off")
cbar = fig.colorbar(im)