In [None]:
# Autoreload modules: with mode 2, IPython re-imports edited modules before each
# cell executes, so changes to the project's local .py files take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Make the project's local modules importable: they live in a `Modules`
# directory located in the parent of the current working directory.
import sys,os

modules_dir = os.path.join(os.path.dirname(os.path.realpath('')), 'Modules')

# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if modules_dir not in sys.path:
    sys.path.append(modules_dir)

In [None]:
import numpy as np

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import Accuracy

from torchvision.transforms import Resize, Compose, GaussianBlur, RandomRotation, RandomChoice, RandomApply, RandomAffine, ColorJitter, RandomHorizontalFlip, RandomVerticalFlip
from dataloader.transforms import GaussianNoise

from copy import deepcopy

from colorama import Fore

from matplotlib import pyplot as plt
import matplotlib.colors as mcolors

from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification

from itertools import combinations
from ast import literal_eval

from dataloader.dataset import ADNI3Channels
from dataloader.dataloader import ADNILoader
from atlas.atlas import ReadersAtlas3Channels, AAL3Channels

from utils.utils import count_parameters, save_model
from utils.report import sklearn_classification_report, custom_classification_report

# Dataset and Dataloader Setup

In [None]:
# Spatial size every image is resized to before it is fed to the ViT.
image_size = (384, 384)
resize = Resize(size=image_size)

# --- Intensity-level augmentation primitives -------------------------------
gaussian_blur = GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2))
gaussian_noise = GaussianNoise(mean=0, std=0.05)
color_jitter_brightness = ColorJitter(brightness=0.1)
color_jitter_contrast = ColorJitter(contrast=0.1)
color_jitter_saturation = ColorJitter(saturation=0.1)

# --- Geometric augmentation primitives -------------------------------------
# Defined for experimentation but deliberately left out of the pool below.
random_rotation = RandomRotation(degrees=2)
random_translate = RandomAffine(degrees=0, translate=(0.01, 0.01))
random_vertical_flip = RandomVerticalFlip(0.5)
random_horizontal_flip = RandomHorizontalFlip(0.5)

# With probability 0.7, apply exactly one transform drawn from the pool.
augmentation_pool = [
    gaussian_blur,
    gaussian_noise,
    color_jitter_brightness,
    color_jitter_contrast,
    color_jitter_saturation,
]
random_choice = RandomChoice(augmentation_pool)
random_transforms = RandomApply([random_choice], p=0.7)

# All three pipelines are currently empty: `random_transforms` is built above
# but not wired into any of them.
train_transforms, valid_transforms, test_transforms = Compose([]), Compose([]), Compose([])

In [None]:
def _load_split(folder, split_transforms):
    """Build the 3-channel ADNI dataset stored under `folder` (rotation enabled)."""
    return ADNI3Channels(folder, transforms=split_transforms, rotate=True)


# One dataset per split, each paired with its own transform pipeline.
train_ds = _load_split("../Data/Training/", train_transforms)
valid_ds = _load_split("../Data/Validation/", valid_transforms)
test_ds = _load_split("../Data/Test/", test_transforms)

In [None]:
# Sanity check: inspect one training sample and the size of every split.
idx = 0
image, label = train_ds[idx]

print("Image shape:", image.shape)
print("Label:", label.item())

for split_name, split_ds in (("training", train_ds), ("validation", valid_ds)):
    print(f"Number of {split_name} samples:", len(split_ds))
print("Number of test samples:", len(test_ds), "\n")

# Show the first three channels of the sample side by side.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for channel, channel_ax in enumerate(axes):
    channel_ax.imshow(image[channel, :, :])
    channel_ax.axis("off");

In [None]:
# Class index <-> diagnosis name. The reverse map is derived from the forward
# one so the two can never drift apart.
id2label = dict(enumerate(("CN", "MCI", "AD")))
label2id = {name: class_id for class_id, name in id2label.items()}

print(id2label[label.item()])

In [None]:
# Batch sizes per split.
train_batch_size = 10
valid_batch_size = 5
test_batch_size = 5

# Ask for at most 20 loader workers, but never more than the machine has CPU
# cores (os.cpu_count() may return None, hence the `or 1`).
num_workers = min(20, os.cpu_count() or 1)

hparams = {'train_ds': train_ds,
           'valid_ds': valid_ds,
           'test_ds': test_ds,
           'train_batch_size': train_batch_size,
           'valid_batch_size': valid_batch_size,
           'test_batch_size': test_batch_size,
           'num_workers': num_workers,
           'train_shuffle': True,
           'valid_shuffle': False,
           'test_shuffle': False,
           'train_drop_last': False,
           'valid_drop_last': False,
           'test_drop_last': False,
          }

# Build the loader factory once and derive all three dataloaders from it,
# instead of constructing three identical ADNILoader objects.
adni_loader = ADNILoader(**hparams)
train_dataloader = adni_loader.train_dataloader()
valid_dataloader = adni_loader.validation_dataloader()
test_dataloader = adni_loader.test_dataloader()

# Smoke test: pull one batch and show the image / label tensor shapes.
batch = next(iter(train_dataloader))
print(batch[0].shape)
print(batch[1].shape)

# Atlas

In [None]:
# Load the 3-channel atlas volume together with its region-label lookup.
atlas_reader = ReadersAtlas3Channels(aal_dir='../Data/AAL/Resized_AAL.nii',
                                     labels_dir='../Data/AAL/ROI_MNI_V4.txt',
                                     rotate=True)
atlas_data, atlas_labels = atlas_reader.get_data()

print(atlas_data.shape, '\n')
print(atlas_labels, '\n')

# Show the three atlas channels side by side.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for channel, channel_ax in enumerate(axes):
    channel_ax.imshow(atlas_data[channel, :, :])
    channel_ax.axis("off");

# Regions Combinations

In [None]:
# Every unordered pair of atlas regions; one model is trained per pair.
region_comb = list(combinations(atlas_labels, 2))
len(region_comb)

# Model Development

In [None]:
# Switch between fine-tuning the published checkpoint (True) and training a
# ViT from scratch with the configuration below (False).
pretrained = True

# Architecture settings used only when the model is built from scratch.
scratch_vit_kwargs = dict(
    image_size=image_size,
    patch_size=32,
    num_labels=3,
    output_attentions=True,
    hidden_dropout_prob=0.1,
)
vit_config = ViTConfig(**scratch_vit_kwargs)

In [None]:
class ViT(nn.Module):
    """Thin wrapper around Hugging Face's `ViTForImageClassification`.

    Args:
        num_labels: Number of output classes of the classification head.
        pretrained: When True, start from the 'google/vit-base-patch32-384'
            checkpoint (the classification head is allowed to differ in size
            via `ignore_mismatched_sizes=True`). When False, build an
            untrained model from a `ViTConfig`.
        config: Optional `ViTConfig` for the from-scratch path. Defaults to the
            notebook-level `vit_config`. Ignored when `pretrained` is True.
    """

    def __init__(self, num_labels=3, pretrained=False, config=None):
        super().__init__()
        self.pretrained = pretrained

        if self.pretrained:
            self.vit = ViTForImageClassification.from_pretrained('google/vit-base-patch32-384',
                                                                 output_attentions=True,
                                                                 num_labels=num_labels,
                                                                 hidden_dropout_prob=0.1,
                                                                 # attention_probs_dropout_prob=0.1,
                                                                 ignore_mismatched_sizes=True
                                                                )
        else:
            # Work on a copy so the shared config object is never mutated, and
            # make the head honour `num_labels` (it used to be silently ignored
            # on this path).
            scratch_config = deepcopy(vit_config if config is None else config)
            scratch_config.num_labels = num_labels
            self.vit = ViTForImageClassification(scratch_config)

    def forward(self, x):
        """Run the wrapped model on `x` and return `(logits, attentions)`.

        `attentions` is whatever the wrapped model exposes as
        `outputs.attentions`; it is only populated when the model was
        configured with `output_attentions=True`.
        """
        outputs = self.vit(x)
        return outputs.logits, outputs.attentions

In [None]:
# Selecting GPU
# Keys 0 and 1 pin a specific card; key 2 is the generic fallback
# (default CUDA device when available, otherwise the CPU).
GPU = {0: torch.device('cuda:0'),
       1: torch.device('cuda:1'),
       2: torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      }

## Single-GPU training
# Prefer the second card, but fall back gracefully on machines that have
# fewer GPUs (or none) instead of crashing in `.to(device)`.
preferred_gpu = 1
device = GPU[preferred_gpu] if torch.cuda.device_count() > preferred_gpu else GPU[2]
model = ViT(num_labels=3, pretrained=pretrained).to(device)

## Multi-GPU training
# device = GPU[2]
# model = ViT(num_labels=3, pretrained=pretrained)
# model= nn.DataParallel(model)
# model.to(device);

# The feature extractor is only used as a container/format converter here:
# resizing and normalisation are both switched off.
if pretrained:
    feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch32-384',
                                                            do_resize=False,
                                                            do_normalize=False)
else:
    feature_extractor = ViTFeatureExtractor(do_resize=False,
                                            size=image_size,
                                            do_normalize=False)

# NOTE(review): the training cell below rebuilds the model, optimizer,
# criterion and scheduler for every region pair with its own hyper-parameters
# (lr=1e-4, no weight decay), so the objects created here are not the ones
# that actually get trained.
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.15)

# Inverse-frequency class weights for the loss.
class_0_freq = 140
class_1_freq = 160
class_2_freq = 160
weight = torch.tensor([1/class_0_freq, 1/class_1_freq, 1/class_2_freq]).to(device)
criterion = nn.CrossEntropyLoss(weight)

accuracy = Accuracy()
writer = SummaryWriter()
scheduler = ExponentialLR(optimizer, gamma=0.999)

In [None]:
# Train one fresh ViT per pair of atlas regions. For every pair the input
# images are restricted to those two regions, the model is trained for
# `epochs` epochs, and the per-epoch mean loss / accuracy are recorded.
epochs = 30

# File the "Regions' Importance" section reads back.
history_path = "history_comb.txt"

train_loss_history = {comb_i: [] for comb_i in range(len(region_comb))}
train_acc_history = {comb_i: [] for comb_i in range(len(region_comb))}

# Class counts used for the inverse-frequency loss weights.
class_0_freq = 140
class_1_freq = 160
class_2_freq = 160

# The feature extractor is identical for every combination, so load it once.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch32-384',
                                                        do_resize=False,
                                                        do_normalize=False)

for comb_i, (region_a, region_b) in enumerate(region_comb):
    print(Fore.WHITE + f'Region: {comb_i} {region_comb[comb_i]}')

    train_accs = []
    train_losses = []
    best_loss = float('inf')
    best_acc = 0

    # Fresh model / optimisation state for this region pair.
    model = ViT(num_labels=3, pretrained=pretrained).to(device)
    optimizer = AdamW(model.parameters(), lr=1e-4)
    weight = torch.tensor([1/class_0_freq, 1/class_1_freq, 1/class_2_freq]).to(device)
    criterion = nn.CrossEntropyLoss(weight)
    accuracy = Accuracy()
    writer = SummaryWriter()
    scheduler = ExponentialLR(optimizer, gamma=0.999)

    # Keep only the voxels that belong to either of the two regions.
    mask1 = torch.where(atlas_data==atlas_labels[region_a], 1, 0)
    mask2 = torch.where(atlas_data==atlas_labels[region_b], 1, 0)
    mask = mask1 + mask2
    # NOTE(review): `atlas_subregion` carries the atlas label values inside the
    # selected regions (it is atlas_data * mask, not the 0/1 mask), so the
    # images below are scaled by the label value. Confirm this is intended
    # rather than multiplying by `mask`.
    atlas_subregion = atlas_data * mask

    for epoch in range(epochs):
        # `from_pretrained` returns the network in eval mode, so switch to
        # training mode *before* the first batch; otherwise the first epoch of
        # every combination runs with dropout disabled.
        model.train()

        for step, (x, y) in enumerate(train_dataloader):

            x *= atlas_subregion
            x = resize(x)

            # One array per sample for the feature extractor. The split count
            # comes from the actual batch, so a shorter final batch works too.
            images = [np.squeeze(sample) for sample in np.split(np.array(x), len(x))]
            x = torch.tensor(np.stack(feature_extractor(images)['pixel_values'], axis=0))
            x, y = x.to(device), y.to(device)

            logits, _ = model(x)
            loss = criterion(logits, y)
            preds = logits.argmax(1)
            # torchmetrics expects (preds, target).
            acc = accuracy(preds.cpu(), y.cpu())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            train_accs.append(acc.item())

        # Epoch-level means over all batches.
        train_loss = sum(train_losses)/len(train_losses)
        train_acc = sum(train_accs)/len(train_accs)

        train_loss_history[comb_i].append(train_loss)
        train_acc_history[comb_i].append(train_acc)

        writer.add_scalar('train_loss', train_loss, epoch * len(train_dataloader) + step)
        writer.add_scalar('train_acc', train_acc, epoch * len(train_dataloader) + step)

        train_losses.clear()
        train_accs.clear()

        # Track the best epoch so far. Checkpointing is currently disabled, so
        # an improvement is only reflected in the colour of the log line.
        improved = False
        if best_loss > train_loss:
            best_loss = train_loss
            improved = True
        if best_acc < train_acc:
            best_acc = train_acc
            improved = True

        line_color = Fore.GREEN if improved else Fore.RED
        print(line_color + f"Epoch: {(epoch+1):02}/{epochs} | Training Loss(Accuracy): {train_loss:.2f}({train_acc:.2f})")

        scheduler.step()

    # Flush and release this combination's TensorBoard event file.
    writer.close()

    # Persist after every combination so partial results survive an
    # interrupted run.
    with open(history_path, 'w') as f:
        f.write('train_loss_history = ' + str(train_loss_history) + '\n\n')
        f.write('train_acc_history = ' + str(train_acc_history))

    print(Fore.YELLOW + "=" * 74)

# Regions' Importance

In [None]:
# Opening the history file and converting strings to dictionaries
def _parse_history_line(lines, name):
    """Return the dict literal stored on the line that starts with `name`.

    Each relevant line has the form ``<name> = {...}``; only the ``{...}`` part
    is evaluated, with `literal_eval`.

    Raises:
        ValueError: if no line starts with `name`.
    """
    for line in lines:
        if line.startswith(name):
            return literal_eval(line[line.find('{'): line.rfind('}') + 1])
    raise ValueError(f"no '{name}' entry found in the history file")


# The context manager closes the file even if reading fails.
with open('history_comb.txt', 'r') as f:
    history_lines = f.read().split('\n')

# Look the entries up by name rather than by position, so blank or trailing
# lines in the file cannot shift them.
train_loss_history = _parse_history_line(history_lines, 'train_loss_history')
train_acc_history = _parse_history_line(history_lines, 'train_acc_history')

In [None]:
# Finding the lowest loss and highest accuracy for each region combination,
# then sorting the combinations from best to worst.
def _best_per_combination(history, pick):
    """Reduce each combination's per-epoch list to one value using `pick`.

    Combinations with an empty history (not trained yet) are skipped. Values
    that are already scalars are passed through unchanged, which makes this
    cell safe to re-run.
    """
    best = {}
    for comb_key, values in history.items():
        if isinstance(values, (list, tuple)):
            if values:
                best[comb_key] = pick(values)
        else:
            best[comb_key] = values
    return best


best_loss_by_comb = _best_per_combination(train_loss_history, min)
best_acc_by_comb = _best_per_combination(train_acc_history, max)

# Sorting regions based on loss (ascending) and accuracy (descending)
train_loss_history = dict(sorted(best_loss_by_comb.items(), key=lambda item: item[1], reverse=False))
train_acc_history = dict(sorted(best_acc_by_comb.items(), key=lambda item: item[1], reverse=True))

In [None]:
# Bar chart of the best training loss per region combination, in ranked order.
fig, ax = plt.subplots(figsize=(25, 10), dpi=300)

bar_positions = []
bar_heights = []
legend_labels = []

for rank, (comb_key, best_value) in enumerate(train_loss_history.items()):
    bar_positions.append(str(rank))
    bar_heights.append(best_value)
    legend_labels.append(f'{rank}: {region_comb[comb_key]}')

# One colour per bar, however many combinations were evaluated.
# `plt.get_cmap` replaces `plt.cm.get_cmap`, which newer matplotlib removed.
colors = plt.get_cmap('viridis')(np.linspace(0, 1, len(bar_heights)))

bars = ax.bar(x=bar_positions,
              height=bar_heights,
              width=0.6,
              color=colors,
              edgecolor='black')

ax.set_xlabel('Model', fontsize=20)
ax.set_ylabel('Training Loss', fontsize=20)
ax.tick_params(axis='both', labelsize=15)
# The legend maps each bar's rank to its pair of region names.
ax.legend(handles=bars, labels=legend_labels, fontsize=15, loc='center left', bbox_to_anchor=(1, 0.5))
fig.tight_layout()

In [None]:
# Bar chart of the best training accuracy per region combination, in ranked
# order.
fig, ax = plt.subplots(figsize=(25, 10), dpi=300)

bar_positions = []
bar_heights = []
legend_labels = []

for rank, (comb_key, best_value) in enumerate(train_acc_history.items()):
    bar_positions.append(str(rank))
    bar_heights.append(best_value)
    legend_labels.append(f'{rank}: {region_comb[comb_key]}')

# One colour per bar, however many combinations were evaluated.
# `plt.get_cmap` replaces `plt.cm.get_cmap`, which newer matplotlib removed.
colors = plt.get_cmap('viridis')(np.linspace(0, 1, len(bar_heights)))

bars = ax.bar(x=bar_positions,
              height=bar_heights,
              width=0.6,
              color=colors,
              edgecolor='black')

ax.set_xlabel('Model', fontsize=20)
ax.set_ylabel('Training Accuracy (Region Importance)', fontsize=20)
ax.tick_params(axis='both', labelsize=15)
# The legend maps each bar's rank to its pair of region names.
ax.legend(handles=bars, labels=legend_labels, fontsize=15, loc='center left', bbox_to_anchor=(1, 0.5))
fig.tight_layout()