In [None]:
import os
import copy
import pickle
import torch
import random
import time
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import confusion_matrix, roc_auc_score
from torchcam.cams import GradCAMpp, SmoothGradCAMpp
from scipy.ndimage import zoom
from nilearn.datasets import fetch_atlas_aal

___

# Config

In [None]:
# Pipeline switches (read by later cells).
train_model = True       # train networks; when False, load <folder>/model.pt instead
train_full = True        # train on full hemispheres; when False, run the ranked-region ablation sweeps
train_CV = True          # with train_full: CV per hemisphere; otherwise a single pure train/val run
test_model = False       # run the evaluation cell on a saved checkpoint
produce_CAM = False      # compute Grad-CAM++ maps; when False, load precomputed maps from <folder>
generateSplits = False   # create new split files; when False, load them from <folder>

# Experiment directory holding shuffler.npy, split pickles, checkpoints and CAM maps.
folder = 'k10b'

# Seed every RNG the notebook relies on: StratifiedKFold(shuffle=True) without random_state draws
# from NumPy's global RNG; torch drives weight initialisation, dropout and DataLoader shuffling.
# GPU (cuDNN) kernels may still be non-deterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

___

# AAL atlas

In [None]:
# Cropped AAL atlas volume; getAllImages resamples every scan onto this grid.
aal_img = nib.load('./AAL/AAL.nii').get_fdata()[5:85, 8:103, 3:80]

# Mapping from AAL region name (including 'Background') to its label value in `aal_img`.
# NOTE: pickle.load can execute arbitrary code; only load a labels file you created yourself.
with open("./AAL/labels.pkl", "rb") as labels_file:
    aal_labels = pickle.load(labels_file)

___

# Data

In [None]:
class ADNIsetPreloaded(torch.utils.data.Dataset):
    """In-memory dataset of (image, label) pairs in which AAL regions outside `include` are zeroed.

    Masking happens in place on the arrays passed as `images` (see `_mask_images`), so the caller's
    array is modified; pass a copy if the original must be preserved.
    """

    def __init__(self, images, labels, classes, include):
        super().__init__()
        # The mask has to exist before the images can be masked.
        self._prepare_mask(include)
        self.classes, self.labels = classes, labels
        self.images = self._mask_images(images)

    def getClassCounts(self):
        """Return {class index: number of samples} for class indices 0, 1 and 2."""
        tally = dict.fromkeys((0, 1, 2), 0)
        for lbl in self.labels:
            tally[lbl] += 1
        return tally

    def showImage(self, idx):
        """Display axial slice 45 of image `idx`."""
        volume = self.images[idx]
        plt.imshow(volume[0, :, :, 45])

    def _prepare_mask(self, include):
        """Build `self.mask`: True for voxels whose AAL region is not listed in `include`.

        Voxels labelled -1 in the atlas are always masked (presumably there are none; TODO confirm).
        """
        mask = aal_img == -1
        excluded = (region for region in aal_labels.keys() if region not in include)
        for region in excluded:
            mask = mask | (aal_img == aal_labels[region])
        self.mask = mask

    def _compute_labels(self, imagePaths, classes):
        """Map each path's parent directory name (e.g. 'CN') to its class index via `classes`."""
        return [classes[path.split('/')[-2]] for path in imagePaths]

    def _mask_images(self, images):
        """Zero the masked voxels of channel 0 of every image, in place, and return `images`."""
        for volume in images:
            volume[0][self.mask] = 0
        return images

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

    def __len__(self):
        return len(self.images)

In [None]:
def getAllImages(shuffler, path = '../thesis-data2/ADNI_Soft/'):
    """Load every preprocessed scan, resample it onto the cropped AAL grid and apply `shuffler`.

    Parameters
    ----------
    shuffler : array of int
        Permutation applied to the concatenated CN, MCI, AD image list.
    path : str
        Root directory with one sub-directory per diagnosis ('CN', 'MCI', 'AD'); must end with '/'.

    Returns
    -------
    images : np.ndarray, shape (n, 1, *aal_img.shape), float32
    labels : np.ndarray of int, 0 = CN, 1 = MCI, 2 = AD
    """
    categories = ['CN', 'MCI', 'AD']
    # NOTE: os.listdir order is filesystem-dependent and is deliberately left unsorted here:
    # `shuffler` and any saved split indices are only meaningful if this order matches the one
    # used when they were generated.
    files_per_category = [[path + category + '/' + name for name in os.listdir(path + category + '/')]
                          for category in categories]

    # Sampling grid that resamples each scan onto the cropped AAL grid. grid_sample takes the last
    # grid axis as (x, y, z) = (width, height, depth), hence the reversed stacking order.
    scaler = (torch.linspace(-1, 1, aal_img.shape[0]),
              torch.linspace(-1, 1, aal_img.shape[1]),
              torch.linspace(-1, 1, aal_img.shape[2]))
    meshz, meshy, meshx = torch.meshgrid(scaler)
    grid = torch.stack((meshx, meshy, meshz), 3).unsqueeze(0)

    images = []
    for category_files in files_per_category:
        for file_path in category_files:
            volume = torch.from_numpy(nib.load(file_path).get_fdata()[np.newaxis, :, :, :]).float()
            resampled = F.grid_sample(volume[np.newaxis, :, :, :, :], grid, align_corners = True)
            images.append(resampled[0].numpy())

    images = np.array(images)
    labels = np.array([label for label, category_files in enumerate(files_per_category)
                       for _ in category_files])

    return images[shuffler], labels[shuffler]

In [None]:
def splitData(images, labels, ratio, random_state = None):
    """Return (train indices, held-out indices) for one stratified hold-out split.

    Only the first fold of a shuffled `StratifiedKFold(ratio)` is used, so the held-out part holds
    roughly 1/`ratio` of the samples with class proportions preserved.

    Parameters
    ----------
    images : array-like
        Only its length is used by StratifiedKFold.
    labels : array-like
        Class labels used for stratification.
    ratio : int
        Number of folds; the held-out set is about 1/ratio of the data.
    random_state : int or None
        Forwarded to StratifiedKFold; None (default) draws from NumPy's global RNG as before.
    """
    skf = StratifiedKFold(ratio, shuffle = True, random_state = random_state)
    # No need to materialise images[train_idxs] here: that fancy-index copy of the training
    # images was never used.
    train_idxs, test_idxs = next(skf.split(images, labels))
    return train_idxs, test_idxs

In [None]:
def foldData(images, labels, n_splits = 10, random_state = None):
    """Split the samples into `n_splits` stratified, shuffled cross-validation folds.

    Parameters
    ----------
    images : array-like
        Only its length is used by StratifiedKFold.
    labels : array-like
        Class labels used for stratification.
    n_splits : int
        Number of folds (default 10, as before).
    random_state : int or None
        Forwarded to StratifiedKFold; None (default) draws from NumPy's global RNG.

    Returns
    -------
    (train_splits, val_splits): two lists of index arrays, one entry per fold.
    """
    skf = StratifiedKFold(n_splits, shuffle = True, random_state = random_state)
    folds = list(skf.split(images, labels))
    train_splits = [train_idxs for train_idxs, _ in folds]
    val_splits = [val_idxs for _, val_idxs in folds]
    return train_splits, val_splits

In [None]:
def createSet(images, labels, include, batch_size = 6, shuffle = True):
    """Build a DataLoader over an ADNIsetPreloaded that keeps only the AAL regions in `include`.

    The dataset zeroes excluded regions in place in `images`, so pass a copy (e.g. a fancy-indexed
    slice) if the caller's array must stay intact.

    Parameters
    ----------
    images, labels : arrays of pre-loaded images and their class indices.
    include : list of AAL region names to keep.
    batch_size : int
        DataLoader batch size (default 6, as before).
    shuffle : bool
        Whether the DataLoader reshuffles every epoch (default True, as before).
    """
    classes = {'CN': 0, 'MCI': 1, 'AD': 2}

    dataset = ADNIsetPreloaded(images, labels, classes, include = include)
    return torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle = shuffle)

In [None]:
class Net(nn.Module):
    """3-D CNN classifying one-channel brain volumes into 3 classes (CN, MCI, AD).

    Three conv(3x3x3) -> leaky ReLU -> max-pool(2) -> batch-norm stages, then three fully
    connected layers. fc_1 expects 64 * 8 * 10 * 7 = 35840 features, which is what an
    80 x 95 x 77 input (the cropped AAL grid) produces after the three stages.
    Layer attribute names and creation order determine the state_dict keys of saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        cube = (3, 3, 3)
        halve = (2, 2, 2)

        self.conv_1, self.pool_1, self.batch_1 = nn.Conv3d(1, 16, kernel_size = cube), nn.MaxPool3d(halve), nn.BatchNorm3d(16)
        self.conv_2, self.pool_2, self.batch_2 = nn.Conv3d(16, 32, kernel_size = cube), nn.MaxPool3d(halve), nn.BatchNorm3d(32)
        self.conv_3, self.pool_3, self.batch_3 = nn.Conv3d(32, 64, kernel_size = cube), nn.MaxPool3d(halve), nn.BatchNorm3d(64)

        self.fc_1, self.fc_2, self.fc_3 = nn.Linear(35840, 128), nn.Linear(128, 64), nn.Linear(64, 3)

        # Shared between the two hidden fully connected layers.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return unnormalised class scores (logits) of shape (batch, 3)."""
        stages = ((self.conv_1, self.pool_1, self.batch_1),
                  (self.conv_2, self.pool_2, self.batch_2),
                  (self.conv_3, self.pool_3, self.batch_3))
        for conv, pool, norm in stages:
            x = norm(pool(F.leaky_relu(conv(x))))

        x = x.view(-1, self.num_flat_features(x))

        for hidden in (self.fc_1, self.fc_2):
            x = self.dropout(F.leaky_relu(hidden(x)))

        return self.fc_3(x)

    def num_flat_features(self, x):
        """Number of elements per sample: the product of every dimension after the batch axis."""
        count = 1
        for extent in x.size()[1:]:
            count = count * extent
        return count

___

# Training

In [None]:
def prepareNet(classCounts, cn = 1, mci = 1, ad = 1):
    """Create a fresh `Net` on the GPU, a class-weighted cross-entropy loss and (when training) an optimizer.

    Parameters
    ----------
    classCounts : dict
        Number of samples per class index {0: CN, 1: MCI, 2: AD}, e.g. from getClassCounts().
    cn, mci, ad : float
        Extra multipliers applied on top of the automatic class weights.

    Returns
    -------
    (net, criterion, optimizer); optimizer is None when the global `train_model` flag is False.
    """
    net = Net().cuda()

    # Inverse-frequency weights balance the loss across classes: the most frequent class gets
    # weight 1 and rarer classes proportionally more. (The previous count / max formula did the
    # opposite and down-weighted the minority classes.) A class with no samples never occurs as a
    # target in this set, so its weight is irrelevant; 0 avoids dividing by zero.
    balance = max(classCounts.values())
    multipliers = (cn, mci, ad)
    weights = [multiplier * balance / classCounts[idx] if classCounts[idx] > 0 else 0.0
               for idx, multiplier in enumerate(multipliers)]
    criterion = nn.CrossEntropyLoss(weight = torch.tensor(weights).cuda())

    if train_model:
        optimizer = torch.optim.Adadelta(net.parameters(), lr = 0.01, weight_decay = 0.00001)
    else:
        optimizer = None

    return net, criterion, optimizer

In [None]:
def _run_epoch(net, loader, criterion, optimizer = None):
    """One pass over `loader`: a training pass when `optimizer` is given, otherwise an evaluation pass.

    Returns
    -------
    total_loss : float
        Sum of the per-batch losses (divide by len(loader) for the mean batch loss).
    correct : int
        Number of samples whose arg-max prediction equals the label.
    raw_outputs, probabilities : np.ndarray, shape (len(loader.dataset), 3)
        Logits and their softmax, in the order the loader yielded the samples.
    predictions, targets : np.ndarray, shape (len(loader.dataset),)
        Predicted and true class indices, in the same order.
    All arrays are freshly allocated on every call.
    """
    size = len(loader.dataset)
    raw_outputs = np.zeros((size, 3))
    probabilities = np.zeros((size, 3))
    predictions = np.zeros(size)
    targets = np.zeros(size)
    total_loss = 0
    correct = 0
    training = optimizer is not None

    net.train(training)
    start = 0
    # Gradients are only needed for the training pass.
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.cuda()
            labels = labels.cuda()

            if training:
                optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            batch_predictions = torch.argmax(outputs, dim = 1)
            correct += torch.sum(batch_predictions == labels).item()

            # Store the whole batch at once; `start` tracks the write position so a shorter
            # last batch is handled correctly.
            stop = start + labels.shape[0]
            raw_outputs[start:stop] = outputs.detach().cpu().numpy()
            probabilities[start:stop] = F.softmax(outputs, dim = 1).detach().cpu().numpy()
            predictions[start:stop] = batch_predictions.cpu().numpy()
            targets[start:stop] = labels.cpu().numpy()
            start = stop

    return total_loss, correct, raw_outputs, probabilities, predictions, targets


def trainNet(net, criterion, optimizer, trainloader, valloader, epochs = 35, verbose = True, save = True):
    """Train `net` for `epochs` epochs, validating after each one, and track the best validation epochs.

    Only runs when the global `train_model` flag is set; otherwise returns None.

    Parameters
    ----------
    net, criterion, optimizer : as returned by prepareNet.
    trainloader, valloader : DataLoaders (see createSet).
    epochs : int
        Number of training epochs.
    verbose : bool
        Print per-epoch metrics, save notifications and a confusion matrix every 10 epochs.
    save : bool
        Write the state dict to 'model.pt' whenever validation accuracy improves.

    Returns
    -------
    A 6-tuple with the validation logits and labels from the epoch with the best accuracy, the
    best one-vs-one AuC and the lowest mean batch loss (in that order), or None when
    `train_model` is False.
    """
    if not train_model:
        return None

    best_val_acc_epoch, best_val_acc = 0, 0
    best_val_acc_raw_outputs, best_val_acc_labels = None, None

    best_val_auc_epoch, best_val_auc = 0, 0
    best_val_auc_raw_outputs, best_val_auc_labels = None, None

    best_val_loss_epoch, best_val_loss = 0, 1000000
    best_val_loss_raw_outputs, best_val_loss_labels = None, None

    startTime = time.time()

    for epoch in range(epochs):
        train_loss, train_correct, _, train_outputs, _, train_labels = _run_epoch(net, trainloader, criterion, optimizer)
        val_loss, val_correct, val_raw_outputs, val_outputs, val_predictions, val_labels = _run_epoch(net, valloader, criterion)

        # Each validation metric is computed once per epoch (the AuC used to be recomputed up to three times).
        val_acc = val_correct / len(valloader.dataset)
        val_auc = roc_auc_score(val_labels, val_outputs, multi_class = 'ovo')
        val_mean_loss = val_loss / len(valloader)

        if verbose:
            remaining = int(((time.time() - startTime) / (epoch + 1)) * (epochs - epoch - 1))
            print('T-' + str(remaining), 'Epoch:', epoch + 1,
                  '~ Train Loss:', int(1000 * train_loss / len(trainloader)) / 1000,
                  '~ Train Acc:', int(1000 * train_correct / len(trainloader.dataset)) / 1000,
                  '~ Train AuC:', int(1000 * roc_auc_score(train_labels, train_outputs, multi_class = 'ovo')) / 1000,
                  '~ Val Loss:', int(1000 * val_loss / len(valloader)) / 1000,
                  '~ Val Acc:', int(1000 * val_correct / len(valloader.dataset)) / 1000,
                  '~ Val AuC:', int(1000 * val_auc) / 1000)

        if val_acc > best_val_acc:
            best_val_acc_epoch = epoch
            best_val_acc = val_acc
            best_val_acc_raw_outputs = val_raw_outputs
            best_val_acc_labels = val_labels

            if save:
                torch.save(net.state_dict(), 'model.pt')
            if verbose:
                print('Saving new best model')

        if val_auc > best_val_auc:
            best_val_auc_epoch = epoch
            best_val_auc = val_auc
            best_val_auc_raw_outputs = val_raw_outputs
            best_val_auc_labels = val_labels

        if val_mean_loss < best_val_loss:
            best_val_loss_epoch = epoch
            best_val_loss = val_mean_loss
            best_val_loss_raw_outputs = val_raw_outputs
            best_val_loss_labels = val_labels

        if verbose and (epoch + 1) % 10 == 0:
            print(confusion_matrix(val_labels, val_predictions))

    print('Best accuracy (', best_val_acc, ') during epoch', best_val_acc_epoch, '. Best AuC (', best_val_auc, ') during epoch', best_val_auc_epoch, '. Best Loss (', best_val_loss, ') during epoch', best_val_loss_epoch)
    return best_val_acc_raw_outputs, best_val_acc_labels, best_val_auc_raw_outputs, best_val_auc_labels, best_val_loss_raw_outputs, best_val_loss_labels

In [None]:
# `shuffler` is a saved permutation of the concatenated CN/MCI/AD image list; every split index
# below refers to positions in the shuffled arrays.
shuffler = np.load(folder + '/shuffler.npy')

images, labels = getAllImages(shuffler)

SPLIT_NAMES = ('train_splits', 'val_splits', 'all_train_idxs', 'test_idxs', 'pure_train_idxs', 'pure_val_idxs')

if generateSplits:
    # Hold out ~1/10 of all images as the test set, then split the remaining training part into
    # CV folds and into a single ~1/5 pure validation split.
    all_train_idxs, test_idxs = splitData(images, labels, 10)
    # StratifiedKFold only needs the number of samples from X, so a zero placeholder replaces
    # fancy-indexed copies of the whole training image set.
    train_placeholder = np.zeros(len(all_train_idxs))
    train_splits, val_splits = foldData(train_placeholder, labels[all_train_idxs])
    pure_train_idxs, pure_val_idxs = splitData(train_placeholder, labels[all_train_idxs], 5)

    # Written to the working directory rather than `folder`, so existing split files are not
    # overwritten; the `else` branch reads from `folder`, so move the files there before running
    # with generateSplits = False.
    new_splits = (train_splits, val_splits, all_train_idxs, test_idxs, pure_train_idxs, pure_val_idxs)
    for split_name, split_value in zip(SPLIT_NAMES, new_splits):
        with open(split_name + '.pickle', 'wb') as fp:
            pickle.dump(split_value, fp)

else:
    # NOTE: pickle.load can execute arbitrary code; only load split files you generated yourself.
    loaded_splits = {}
    for split_name in SPLIT_NAMES:
        with open(folder + '/' + split_name + '.pickle', 'rb') as fp:
            loaded_splits[split_name] = pickle.load(fp)
    train_splits, val_splits, all_train_idxs, test_idxs, pure_train_idxs, pure_val_idxs = (
        loaded_splits[split_name] for split_name in SPLIT_NAMES)

# Sanity checks: split files that do not belong to the loaded data (e.g. a different number of
# scans) fail loudly here instead of silently indexing the wrong images.
assert len(all_train_idxs) + len(test_idxs) == len(images), 'split files do not match the loaded images'
assert not set(all_train_idxs.tolist()).intersection(test_idxs.tolist()), 'test indices overlap the training indices'
assert all(max(max(fold_train), max(fold_val)) < len(all_train_idxs)
           for fold_train, fold_val in zip(train_splits, val_splits)), 'CV fold indices exceed the training set'

In [None]:
# Every AAL region name (including 'Background'), and the same list without the background label.
include_all = list(aal_labels.keys())
exclude_background = include_all.copy()
exclude_background.remove('Background')

# AAL region names of the left hemisphere: 45 non-cerebellar regions followed by 9 cerebellar ones.
include_left = [
    'Precentral_L', 'Frontal_Sup_L', 'Frontal_Sup_Orb_L',
    'Frontal_Mid_L', 'Frontal_Mid_Orb_L', 'Frontal_Inf_Oper_L',
    'Frontal_Inf_Tri_L', 'Frontal_Inf_Orb_L', 'Rolandic_Oper_L',
    'Supp_Motor_Area_L', 'Olfactory_L', 'Frontal_Sup_Medial_L',
    'Frontal_Med_Orb_L', 'Rectus_L', 'Insula_L', 'Cingulum_Ant_L',
    'Cingulum_Mid_L', 'Cingulum_Post_L', 'Hippocampus_L',
    'ParaHippocampal_L', 'Amygdala_L', 'Calcarine_L', 'Cuneus_L',
    'Lingual_L', 'Occipital_Sup_L', 'Occipital_Mid_L',
    'Occipital_Inf_L', 'Fusiform_L', 'Postcentral_L', 'Parietal_Sup_L',
    'Parietal_Inf_L', 'SupraMarginal_L', 'Angular_L', 'Precuneus_L',
    'Paracentral_Lobule_L', 'Caudate_L', 'Putamen_L', 'Pallidum_L',
    'Thalamus_L', 'Heschl_L', 'Temporal_Sup_L', 'Temporal_Pole_Sup_L',
    'Temporal_Mid_L', 'Temporal_Pole_Mid_L', 'Temporal_Inf_L',
    'Cerebelum_Crus1_L', 'Cerebelum_Crus2_L', 'Cerebelum_3_L',
    'Cerebelum_4_5_L', 'Cerebelum_6_L', 'Cerebelum_7b_L',
    'Cerebelum_8_L', 'Cerebelum_9_L', 'Cerebelum_10_L',
]

# The right hemisphere uses the same region names, in the same order, with the '_R' suffix.
include_right = [region[:-len('_L')] + '_R' for region in include_left]

def _train_on_splits(inclusion):
    """Train and validate a fresh network on every CV split, keeping only the AAL regions in `inclusion`.

    Yields (split index, the 6-tuple returned by trainNet) after each split so the caller can checkpoint.
    """
    for split, (train_idxs, val_idxs) in enumerate(zip(train_splits, val_splits)):
        print('Split', split)

        # Fold indices address positions inside `all_train_idxs`; composing the index arrays first
        # copies each fold out of `images` once instead of first copying the whole training set.
        fold_train_idxs = all_train_idxs[train_idxs]
        fold_val_idxs = all_train_idxs[val_idxs]
        trainloader = createSet(images[fold_train_idxs], labels[fold_train_idxs], inclusion)
        valloader = createSet(images[fold_val_idxs], labels[fold_val_idxs], inclusion)

        net, criterion, optimizer = prepareNet(trainloader.dataset.getClassCounts())
        yield split, trainNet(net, criterion, optimizer, trainloader, valloader, epochs = 50, verbose = False, save = False)


def _dump_results(results, path):
    """Pickle `results` to `path`, overwriting the previous checkpoint."""
    with open(path, 'wb') as fp:
        pickle.dump(results, fp)


if train_model:

    if train_full:

        if train_CV:
            for name, inclusion in zip(['left', 'right'], [include_left, include_right]):
                # A fresh dict per hemisphere: with one shared dict, the intermediate 'right'
                # checkpoints still contained the 'left' results of splits not yet retrained.
                results = {}
                for split, split_results in _train_on_splits(inclusion):
                    results[split] = split_results
                    _dump_results(results, name + '_results.pickle')

        else:
            # Single run on the fixed pure train / validation split; the best model is saved to model.pt.
            pure_train = all_train_idxs[pure_train_idxs]
            pure_val = all_train_idxs[pure_val_idxs]
            trainloader = createSet(images[pure_train], labels[pure_train], include_all)
            valloader = createSet(images[pure_val], labels[pure_val], include_all)
            net, criterion, optimizer = prepareNet(trainloader.dataset.getClassCounts())
            trainNet(net, criterion, optimizer, trainloader, valloader, epochs = 50, verbose = True, save = True)

    else:
        # Region-ablation sweeps driven by the ranking stored in <folder>/stats.npy (read with
        # pickle despite the extension).
        # NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
        with open(folder + '/stats.npy', 'rb') as fp:
            rankings = pickle.load(fp)

        rankedRegions = list(rankings['All']['Intensities'].keys())
        rankedRegions.remove('Background')

        sweeps = (('Ignoring the worst', rankedRegions, 'sub_results_normal.pickle'),
                  ('Ignoring the best', rankedRegions[::-1], 'sub_results_reverse.pickle'))
        for message, ordering, out_path in sweeps:
            results = {}
            for i in range(90, 9, -10):
                print(message, i, 'regions')
                # Drop the first `i` regions of this ordering and keep the rest. The previous slice
                # `[-(116 - i):]` silently assumed exactly 116 ranked regions; `[i:]` is identical then.
                for split, split_results in _train_on_splits(ordering[i:]):
                    results[(i, split)] = split_results
                    _dump_results(results, out_path)

else:
    net = Net().cuda()
    net.load_state_dict(torch.load(folder + '/model.pt'))

In [None]:
def evaluateNet(net, criterion, testloader):
    """Evaluate `net` on `testloader`; print loss, accuracy, one-vs-one AuC and the confusion matrix.

    Returns (softmax outputs, predicted classes, true labels) as numpy arrays in loader order.
    Keeping the loop inside a function stops batch variables (notably `labels`) from overwriting
    the notebook's global arrays, which later cells still index.
    """
    size = len(testloader.dataset)
    test_loss = 0
    test_correct = 0
    test_outputs = np.zeros((size, 3))
    test_predictions = np.zeros(size)
    test_labels = np.zeros(size)

    net.eval()
    start = 0
    with torch.no_grad():
        for inputs, batch_labels in testloader:
            inputs = inputs.cuda()
            batch_labels = batch_labels.cuda()

            outputs = net(inputs)
            test_loss += criterion(outputs, batch_labels).item()
            batch_predictions = torch.argmax(outputs, dim = 1)
            test_correct += torch.sum(batch_predictions == batch_labels).item()

            # Position-based writes also handle a shorter last batch.
            stop = start + batch_labels.shape[0]
            test_outputs[start:stop] = F.softmax(outputs, dim = 1).cpu().numpy()
            test_predictions[start:stop] = batch_predictions.cpu().numpy()
            test_labels[start:stop] = batch_labels.cpu().numpy()
            start = stop

    print('Test Loss:', int(1000 * test_loss / len(testloader)) / 1000,
          '~ Test Acc:', int(1000 * test_correct / size) / 1000,
          '~ Test AuC:', int(1000 * roc_auc_score(test_labels, test_outputs, multi_class = 'ovo')) / 1000)

    print(confusion_matrix(test_labels, test_predictions))
    return test_outputs, test_predictions, test_labels


if test_model:

    # Held-out test fold from splitData (disjoint from all_train_idxs). Previously this line was
    # commented out, leaving `testloader` undefined.
    testloader = createSet(images[test_idxs], labels[test_idxs], include_all)
    # prepareNet supplies the class-weighted loss; the network weights come from the checkpoint.
    net, criterion, optimizer = prepareNet(testloader.dataset.getClassCounts())
    net.load_state_dict(torch.load(folder + '/ensamble/model_' + 'ia_1' + '.pt'))

    test_outputs, test_predictions, test_labels = evaluateNet(net, criterion, testloader)

___

# Class activation maps (Grad-CAM++)

In [None]:
# NOTE(review): rebinds `images` from the image array to a Dataset, so this cell is not idempotent;
# the CAM and averaging cells below read `images` as that dataset (pure validation split, all regions).
cam_idxs = all_train_idxs[pure_val_idxs]
images = createSet(images[cam_idxs], labels[cam_idxs], include_all).dataset

In [None]:
# Name of the ensemble member whose checkpoint (<folder>/ensamble/model_<name>.pt) is loaded for the CAMs
# and whose name is appended to the saved CAM map files.
ensamble = 'ia_10'

In [None]:
if produce_CAM:
    # Network for the selected ensemble member, in eval mode, wrapped by a Grad-CAM++ extractor on conv_3.
    cam_checkpoint = folder + '/ensamble/model_' + ensamble + '.pt'
    net = Net().cuda()
    net.load_state_dict(torch.load(cam_checkpoint))
    net.eval()

    # Input shape passed to torchcam: one channel followed by the AAL grid dimensions.
    cam_input_shape = [1] + list(aal_img.shape)
    cam_extractor = GradCAMpp(net, input_shape = cam_input_shape, target_layer = 'conv_3')

In [None]:
if produce_CAM:

    # Accumulate Grad-CAM++ maps over the dataset currently bound to `images`, separately for
    # correctly and wrongly classified images and, within each, per true label.
    #   combined_* : running sum of the resampled CAMs (a mean after the division at the end).
    #   overlap_*  : per voxel, the number of images whose CAM is positive there (a fraction after
    #                the division at the end).
    #   image_count_* : number of images added to the matching accumulators.

    # Correct predictions: all, then per true class.
    combined_activation_map = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    combined_activation_map_CN = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    combined_activation_map_MCI = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    combined_activation_map_AD = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    
    # Wrong predictions: all, then per true class.
    combined_activation_map_wrong = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    combined_activation_map_wrong_CN = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    combined_activation_map_wrong_MCI = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    combined_activation_map_wrong_AD = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    
    overlap_activation_map = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    overlap_activation_map_CN = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    overlap_activation_map_MCI = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    overlap_activation_map_AD = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    
    overlap_activation_map_wrong = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    overlap_activation_map_wrong_CN = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    overlap_activation_map_wrong_MCI = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    overlap_activation_map_wrong_AD = np.zeros((aal_img.shape[0], aal_img.shape[1], aal_img.shape[2]))
    
    image_count_All = 0
    image_count_CN = 0
    image_count_MCI = 0
    image_count_AD = 0
    
    image_count_wrong_All = 0
    image_count_wrong_CN = 0
    image_count_wrong_MCI = 0
    image_count_wrong_AD = 0

    for i, (image, label) in enumerate(zip(images.images, images.labels)):
        
        # Add a batch axis before feeding the single image to the network.
        x = torch.from_numpy(image).cuda()[np.newaxis, :, :, :, :]

        # Not under torch.no_grad(): cam_extractor below is given these scores.
        class_scores = net(x)
        # Predicted class index.
        class_idx = class_scores.squeeze(0).argmax().item()
        
        # Pick the accumulators for this image: the "correct" or "wrong" set, then the TRUE label
        # (not the predicted one). The *_all / *_condition names alias those arrays, so the
        # in-place += further down updates them.
        if class_idx == label:
            image_count_All += 1
            combined_activation_map_all = combined_activation_map
            overlap_activation_map_all = overlap_activation_map
            if label == 0:
                combined_activation_map_condition = combined_activation_map_CN
                overlap_activation_map_condition = overlap_activation_map_CN
                image_count_CN += 1
            elif label == 1:
                combined_activation_map_condition = combined_activation_map_MCI
                overlap_activation_map_condition = overlap_activation_map_MCI
                image_count_MCI += 1
            else:
                combined_activation_map_condition = combined_activation_map_AD
                overlap_activation_map_condition = overlap_activation_map_AD
                image_count_AD += 1
        else:
            image_count_wrong_All += 1
            combined_activation_map_all = combined_activation_map_wrong
            overlap_activation_map_all = overlap_activation_map_wrong
            if label == 0:
                combined_activation_map_condition = combined_activation_map_wrong_CN
                overlap_activation_map_condition = overlap_activation_map_wrong_CN
                image_count_wrong_CN += 1
            elif label == 1:
                combined_activation_map_condition = combined_activation_map_wrong_MCI
                overlap_activation_map_condition = overlap_activation_map_wrong_MCI
                image_count_wrong_MCI += 1
            else:
                combined_activation_map_condition = combined_activation_map_wrong_AD
                overlap_activation_map_condition = overlap_activation_map_wrong_AD
                image_count_wrong_AD += 1

        # CAM for the predicted class, resampled with scipy.ndimage.zoom to the input's spatial size.
        activation_map = cam_extractor(class_idx, class_scores).cpu().numpy()
        scaled_activation_map = zoom(activation_map, (x.shape[2] / activation_map.shape[0], x.shape[3] / activation_map.shape[1], x.shape[4] / activation_map.shape[2]))

        # Zero the map wherever the input voxel is 0 (e.g. regions masked out by ADNIsetPreloaded).
        zero_mask = x == 0
        scaled_activation_map[zero_mask.cpu().numpy()[0, 0, :, :, :]] = 0

        # Per-image normalisation is disabled: this is the same array object as scaled_activation_map.
        normalized_actvation_map = scaled_activation_map# / scaled_activation_map.sum()
        
        combined_activation_map_all += normalized_actvation_map
        combined_activation_map_condition += normalized_actvation_map
        
        # Binarise in place (positive -> 1, negative -> 0) for the overlap counts. This also changes
        # normalized_actvation_map, which has already been accumulated above.
        scaled_activation_map[scaled_activation_map > 0] = 1
        scaled_activation_map[scaled_activation_map < 0] = 0
        
        overlap_activation_map_all += scaled_activation_map
        overlap_activation_map_condition += scaled_activation_map
    
    # Turn sums into means / fractions; groups without any images keep their all-zero maps.
    if image_count_All > 0:
        combined_activation_map = combined_activation_map / image_count_All
    if image_count_CN > 0:
        combined_activation_map_CN = combined_activation_map_CN / image_count_CN
    if image_count_MCI > 0:
        combined_activation_map_MCI = combined_activation_map_MCI / image_count_MCI
    if image_count_AD > 0:
        combined_activation_map_AD = combined_activation_map_AD / image_count_AD
    
    if image_count_wrong_All > 0:
        combined_activation_map_wrong = combined_activation_map_wrong / image_count_wrong_All
    if image_count_wrong_CN > 0:
        combined_activation_map_wrong_CN = combined_activation_map_wrong_CN / image_count_wrong_CN
    if image_count_wrong_MCI > 0:
        combined_activation_map_wrong_MCI = combined_activation_map_wrong_MCI / image_count_wrong_MCI
    if image_count_wrong_AD > 0:
        combined_activation_map_wrong_AD = combined_activation_map_wrong_AD / image_count_wrong_AD
    
    if image_count_All > 0:
        overlap_activation_map = overlap_activation_map / image_count_All
    if image_count_CN > 0:
        overlap_activation_map_CN = overlap_activation_map_CN / image_count_CN
    if image_count_MCI > 0:
        overlap_activation_map_MCI = overlap_activation_map_MCI / image_count_MCI
    if image_count_AD > 0:
        overlap_activation_map_AD = overlap_activation_map_AD / image_count_AD
        
    if image_count_wrong_All > 0:
        overlap_activation_map_wrong = overlap_activation_map_wrong / image_count_wrong_All
    if image_count_wrong_CN > 0:
        overlap_activation_map_wrong_CN = overlap_activation_map_wrong_CN / image_count_wrong_CN
    if image_count_wrong_MCI > 0:
        overlap_activation_map_wrong_MCI = overlap_activation_map_wrong_MCI / image_count_wrong_MCI
    if image_count_wrong_AD > 0:
        overlap_activation_map_wrong_AD = overlap_activation_map_wrong_AD / image_count_wrong_AD

In [None]:
if produce_CAM:
    # (file-name prefix, map) pairs; each map is written to <folder>/ensamble/<prefix><ensamble>.npy
    cam_outputs = (
        ('Map_val_All_', combined_activation_map),
        ('Map_val_CN_', combined_activation_map_CN),
        ('Map_val_MCI_', combined_activation_map_MCI),
        ('Map_val_AD_', combined_activation_map_AD),
        ('Map_val_wrong_All_', combined_activation_map_wrong),
        ('Map_val_wrong_CN_', combined_activation_map_wrong_CN),
        ('Map_val_wrong_MCI_', combined_activation_map_wrong_MCI),
        ('Map_val_wrong_AD_', combined_activation_map_wrong_AD),
        ('Map_val_All_overlap_', overlap_activation_map),
        ('Map_val_CN_overlap_', overlap_activation_map_CN),
        ('Map_val_MCI_overlap_', overlap_activation_map_MCI),
        ('Map_val_AD_overlap_', overlap_activation_map_AD),
        ('Map_val_wrong_All_overlap_', overlap_activation_map_wrong),
        ('Map_val_wrong_CN_overlap_', overlap_activation_map_wrong_CN),
        ('Map_val_wrong_MCI_overlap_', overlap_activation_map_wrong_MCI),
        ('Map_val_wrong_AD_overlap_', overlap_activation_map_wrong_AD),
    )
    for prefix, activation in cam_outputs:
        np.save(folder + '/ensamble/' + prefix + ensamble + '.npy', activation)

In [None]:
if not produce_CAM:
    # Precomputed maps stored directly in `folder` (not the per-ensemble files written by the CAM cells).
    (combined_activation_map, combined_activation_map_CN,
     combined_activation_map_MCI, combined_activation_map_AD) = [
        np.load(folder + '/Map_' + group + '.npy') for group in ('All', 'CN', 'MCI', 'AD')]

In [None]:
# Mean input volume over the dataset currently bound to `images`, overall and per class (0 = CN,
# 1 = MCI, anything else = AD); used as the anatomical background of the CAM figures. Means are only
# taken for groups that contain images (an empty class previously produced NaNs from 0 / 0), in
# line with the guarded normalisation in the CAM cell; empty groups keep an all-zero volume.
average = np.zeros(aal_img.shape)
class_sums = [np.zeros(aal_img.shape) for _ in range(3)]
class_counts = [0, 0, 0]

for volume, label in images:
    group = int(label) if label in (0, 1) else 2
    average += volume[0]
    class_sums[group] += volume[0]
    class_counts[group] += 1

cn_count, mci_count, ad_count = class_counts
total_count = cn_count + mci_count + ad_count
if total_count > 0:
    average = average / total_count
cn_average, mci_average, ad_average = [class_sum / count if count > 0 else class_sum
                                       for class_sum, count in zip(class_sums, class_counts)]

np.save('average.npy', average)
np.save('average_CN.npy', cn_average)
np.save('average_MCI.npy', mci_average)
np.save('average_AD.npy', ad_average)

In [None]:
# Compare the overall and class-wise activation maps at one slice along the last axis.
# Row 0: raw maps; row 1: maps over the AAL atlas; row 2: maps over the average scan.
# All map panels share one colour scale (vmax) so intensities are comparable.
slice_index = 45
vmax = max(combined_activation_map_CN.max(), combined_activation_map_MCI.max(), combined_activation_map_AD.max())
ad_cn_difference = np.absolute(combined_activation_map_AD - combined_activation_map_CN)

# (panel title, map, overlay alpha used in rows 1 and 2)
map_panels = [
    ('All', combined_activation_map, 0.85),
    ('CN', combined_activation_map_CN, 0.85),
    ('MCI', combined_activation_map_MCI, 0.85),
    ('AD', combined_activation_map_AD, 0.85),
    ('|CN - AD|', ad_cn_difference, 0.75),
]
# Row background: None means the map is drawn on its own.
row_backgrounds = [None, aal_img, average]

fig, axs = plt.subplots(3, 6, figsize = (25, 12))
for row_index, background in enumerate(row_backgrounds):
    for col_index, (panel_title, cam, overlay_alpha) in enumerate(map_panels):
        panel_ax = axs[row_index, col_index]
        if background is None:
            panel_ax.imshow(cam[:, :, slice_index], vmax = vmax)
        else:
            panel_ax.imshow(background[:, :, slice_index], cmap = 'gray')
            panel_ax.imshow(cam[:, :, slice_index], alpha = overlay_alpha, vmax = vmax)
        panel_ax.set_title(panel_title)

# Last column: empty in row 0, reference images in rows 1 and 2.
axs[0, 5].set_visible(False)
for row_index, (reference, imshow_kwargs, reference_title) in {
        1: (aal_img, {}, 'AAL atlas'),
        2: (average, {'cmap': 'gray'}, 'Average Scan'),
}.items():
    axs[row_index, 5].imshow(reference[:, :, slice_index], **imshow_kwargs)
    axs[row_index, 5].set_title(reference_title)

___

# AAL region statistics

In [None]:
def _region_sums(cam):
    """Sum of `cam` over each AAL region, keyed by region name (in `aal_labels` order).

    Voxels outside a region are zeroed and the whole volume is summed. This gives the
    same values as the previous copy-and-mask loop without copying `cam` per region.
    NOTE(review): assumes `cam` has the same shape as `aal_img` — confirm for every map.
    """
    return {region: np.where(aal_img == label, cam, 0).sum() for region, label in aal_labels.items()}


def _by_value(scores):
    """Return `scores` re-ordered by ascending value (later cells rank by position)."""
    return dict(sorted(scores.items(), key = lambda item: item[1], reverse = False))


# The stats are pickled (the .npy extension is kept for compatibility). They are written
# to the same path the else-branch reads, so a produce_CAM run can be reloaded later.
stats_path = folder + '/stats.npy'

if produce_CAM:

    stat_conditions = ['All', 'CN', 'MCI', 'AD', 'AD-CN']
    combined_maps = [combined_activation_map, combined_activation_map_CN, combined_activation_map_MCI,
                     combined_activation_map_AD, combined_activation_map_AD - combined_activation_map_CN]
    overlap_maps = [overlap_activation_map, overlap_activation_map_CN, overlap_activation_map_MCI,
                    overlap_activation_map_AD, overlap_activation_map_AD - overlap_activation_map_CN]

    # Voxel count of each region; it does not depend on the map, so compute it once.
    region_volumes = {region: np.count_nonzero(aal_img == label) for region, label in aal_labels.items()}

    all_stats = {}
    for cond_name, cam, overlap_cam in zip(stat_conditions, combined_maps, overlap_maps):
        intensities = _region_sums(cam)
        overlap_sums = _region_sums(overlap_cam)
        # Per-voxel means. A region with zero voxels gives 0/0, so numpy warns and yields NaN.
        densities = {region: intensities[region] / region_volumes[region] for region in region_volumes}
        overlap = {region: overlap_sums[region] / region_volumes[region] for region in region_volumes}

        all_stats[cond_name] = {
            'Volume': _by_value(region_volumes),
            'Intensities': _by_value(intensities),
            'Densities': _by_value(densities),
            'Overlap': _by_value(overlap),
        }

    with open(stats_path, 'wb') as fp:
        pickle.dump(all_stats, fp)

else:
    # NOTE: pickle.load can execute arbitrary code; only load stats produced by this notebook.
    with open(stats_path, 'rb') as fp:
        all_stats = pickle.load(fp)

In [None]:
# One row per AAL region (the 'Background' label included). For every condition it holds
# the summed activation ("Intensity") and per-voxel overlap value, each with its rank.
# The frame is built from a list of records in one step: DataFrame.append was removed in
# pandas 2.0, and growing a frame row by row is quadratic.
stat_conditions = ['All', 'CN', 'MCI', 'AD', 'AD-CN']
# (column label used in the frame, key inside all_stats[condition])
stat_metrics = [('Intensity', 'Intensities'), ('Overlap', 'Overlap')]

stat_columns = ['Region']
for label, stat_key in stat_metrics:
    for cond in stat_conditions:
        stat_columns += [cond + ' ' + label, cond + ' ' + label + ' Rank']

# The all_stats dicts are stored in ascending value order, so the last region gets rank 1.
# Rank = number of regions - position. This matches the former hard-coded `117 - position`
# when there are 117 labels (presumably 116 AAL regions + Background).
stat_ranks = {}
for label, stat_key in stat_metrics:
    for cond in stat_conditions:
        ordered_regions = list(all_stats[cond][stat_key])
        stat_ranks[cond, label] = {region: len(ordered_regions) - position
                                   for position, region in enumerate(ordered_regions)}

stat_records = []
for region in aal_labels.keys():
    record = {'Region': region}
    for label, stat_key in stat_metrics:
        for cond in stat_conditions:
            record[cond + ' ' + label] = all_stats[cond][stat_key][region]
            record[cond + ' ' + label + ' Rank'] = stat_ranks[cond, label][region]
    stat_records.append(record)

all_stats_df = pd.DataFrame(stat_records, columns = stat_columns)

In [None]:
# Keep only the labelled atlas regions; the 'Background' label is not a region of interest.
all_stats_df_regions = all_stats_df.loc[all_stats_df['Region'].ne('Background')]

In [None]:
# Bar chart of one condition/metric per AAL region, ordered by rank (rank 1 first).
condition = 'AD-CN'
metric = 'Intensity'
# Sort once so the bar heights and tick labels come from the same ordering.
ranked_regions = all_stats_df_regions.sort_values(condition + ' ' + metric + ' Rank')
bar_positions = np.arange(len(ranked_regions.index))
fig, ax = plt.subplots(figsize = (30, 10))
ax.bar(bar_positions, list(ranked_regions[condition + ' ' + metric]))
ax.set_xticks(bar_positions)
ax.set_xticklabels(list(ranked_regions['Region']), rotation = 60, ha = 'right')
ax.set_yticks([])
ax.set_title(condition + ' ' + metric + ' by AAL region (ordered by rank)')
plt.show()

In [None]:
# Bar chart of one condition/metric per AAL region, ordered by rank (rank 1 first).
condition = 'AD-CN'
metric = 'Overlap'
# Sort once so the bar heights and tick labels come from the same ordering.
ranked_regions = all_stats_df_regions.sort_values(condition + ' ' + metric + ' Rank')
bar_positions = np.arange(len(ranked_regions.index))
fig, ax = plt.subplots(figsize = (30, 10))
ax.bar(bar_positions, list(ranked_regions[condition + ' ' + metric]))
ax.set_xticks(bar_positions)
ax.set_xticklabels(list(ranked_regions['Region']), rotation = 60, ha = 'right')
ax.set_yticks([])
ax.set_title(condition + ' ' + metric + ' by AAL region (ordered by rank)')
plt.show()

In [None]:
# Bar chart of one condition/metric per AAL region, ordered by rank (rank 1 first).
condition = 'All'
metric = 'Intensity'
# Sort once so the bar heights and tick labels come from the same ordering.
ranked_regions = all_stats_df_regions.sort_values(condition + ' ' + metric + ' Rank')
bar_positions = np.arange(len(ranked_regions.index))
fig, ax = plt.subplots(figsize = (30, 10))
ax.bar(bar_positions, list(ranked_regions[condition + ' ' + metric]))
ax.set_xticks(bar_positions)
ax.set_xticklabels(list(ranked_regions['Region']), rotation = 60, ha = 'right')
ax.set_yticks([])
ax.set_title(condition + ' ' + metric + ' by AAL region (ordered by rank)')
plt.show()