In [None]:
# Automatically reload edited modules (such as the project's `dataloader` / `utils`
# packages imported below) before running code, so edits take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Make the project's local packages (dataloader, utils) importable. They live in a
# `Modules` directory that sits next to the current working directory (../Modules).
import os
import sys

MODULES_DIR = os.path.join(os.path.dirname(os.path.realpath('')), 'Modules')
if MODULES_DIR not in sys.path:  # re-running this cell must not add duplicate entries
    sys.path.append(MODULES_DIR)

In [None]:
import numpy as np

import torch
import torch.nn as nn
from torch.optim import AdamW, Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import Accuracy

from torchvision.transforms import Compose, GaussianBlur, RandomRotation, RandomChoice, RandomApply, RandomAffine
from dataloader.transforms import GaussianNoise

from copy import deepcopy

from colorama import Fore

from matplotlib import pyplot as plt

from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification

from dataloader.dataset import ADNI
from dataloader.dataloader import ADNILoader

from utils.utils import count_parameters, save_model
from utils.report import sklearn_classification_report, custom_classification_report

# Dataset and Dataloader Setup

In [None]:
# Spatial size of each slice; passed to ViTConfig and the feature extractor below.
image_size = (79, 95)

# Candidate augmentations. Rotation and translation are constructed but not
# currently enabled.
gaussian_blur = GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2))
gaussian_noise = GaussianNoise(mean=0, std=0.1)
random_rotation = RandomRotation(degrees=5)
random_translate = RandomAffine(degrees=0, translate=(0.05, 0.05))

# With probability 0.7, apply exactly one enabled augmentation to a training sample;
# otherwise leave it untouched. Append random_rotation / random_translate to enable them.
enabled_augmentations = [gaussian_blur, gaussian_noise]
random_choice = RandomChoice(enabled_augmentations)
random_transforms = RandomApply([random_choice], p=0.7)

# Only the training split is augmented; Compose([]) is an identity transform.
train_transforms = Compose([random_transforms])
valid_transforms, test_transforms = Compose([]), Compose([])

In [None]:
# One ADNI dataset per split directory, each with its own transform pipeline
# (built in order: training, validation, test).
train_ds, valid_ds, test_ds = (
    ADNI(split_dir, transforms=split_transforms)
    for split_dir, split_transforms in (
        ("../Data/Training/", train_transforms),
        ("../Data/Validation/", valid_transforms),
        ("../Data/Test/", test_transforms),
    )
)

In [None]:
# Peek at the first training sample and report the size of each split.
idx = 0
image, label = train_ds[idx]

print("Image shape:", image.shape)
print("Label:", label.item())

for split_name, split_ds in (("training", train_ds), ("validation", valid_ds)):
    print(f"Number of {split_name} samples:", len(split_ds))
print("Number of test samples:", len(test_ds), "\n")

# Show the 60 slices of the volume on a 6 x 10 grid, row-major.
fig, axes = plt.subplots(nrows=6, ncols=10, figsize=(3, 2), dpi=300)
for idx, ax in enumerate(axes.flat):
    ax.imshow(image[idx, :, :])
    ax.axis("off")

In [None]:
# Class index -> class name, and its inverse (derived so the two can never disagree).
id2label = {0: "CN", 1: "MCI", 2: "AD"}
label2id = {name: class_id for class_id, name in id2label.items()}

# Name of the sample inspected above.
print(id2label[label.item()])

In [None]:
train_batch_size = 15
valid_batch_size = 10
test_batch_size = 10

hparams = {'train_ds': train_ds,
           'valid_ds': valid_ds,
           'test_ds': test_ds,
           'train_batch_size': train_batch_size,
           'valid_batch_size': valid_batch_size,
           'test_batch_size': test_batch_size,
           # Never request more worker processes than this machine has cores;
           # PyTorch's DataLoader warns when num_workers exceeds its suggested maximum.
           'num_workers': min(20, os.cpu_count() or 1),
           'train_shuffle': True,
           'valid_shuffle': False,
           'test_shuffle': False,
           'train_drop_last': True,
           'valid_drop_last': False,
           'test_drop_last': False,
          }

# Build the ADNILoader once and take all three loaders from it
# (it was previously constructed three times from identical arguments).
adni_loaders = ADNILoader(**hparams)
train_dataloader = adni_loaders.train_dataloader()
valid_dataloader = adni_loaders.validation_dataloader()
test_dataloader = adni_loaders.test_dataloader()

# Sanity check: shapes of one training batch (inputs, labels).
batch = next(iter(train_dataloader))
print(batch[0].shape)
print(batch[1].shape)

# Model Development

In [None]:
# ViT hyper-parameters. Each sample is fed as a 60-channel image of size `image_size`;
# dropout is disabled.
#
# NOTE(review): 79 and 95 are not multiples of patch_size=32. With a stride-32 patch
# projection only 2 x 2 = 4 patches (a 64 x 64 region) are embedded, and the last
# 15 rows / 31 columns of every slice are never seen by the model. Confirm this is intended.
vit_hparams = dict(
    image_size=image_size,
    patch_size=32,
    num_labels=3,
    num_channels=60,
    # Encoder size (ViT-Base dimensions).
    num_hidden_layers=12,
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    # Required so ViT.forward can return attention maps alongside the logits.
    output_attentions=True,
    hidden_dropout_prob=0,
    attention_probs_dropout_prob=0,
)
vit_config = ViTConfig(**vit_hparams)

In [None]:
class ViT(nn.Module):
    """Thin wrapper around Hugging Face ``ViTForImageClassification``.

    ``forward`` returns ``(logits, attentions)`` instead of the HF output object, so the
    training and evaluation code can unpack it directly.

    Args:
        num_labels: number of output classes. This overrides ``config.num_labels``;
            it was previously ignored and the global config's value was always used.
        config: ``ViTConfig`` to build the model from. Defaults to the notebook-level
            ``vit_config``. It is deep-copied, so the caller's object is never mutated.
    """

    def __init__(self, num_labels=3, config=None):
        super().__init__()
        config = deepcopy(vit_config if config is None else config)
        config.num_labels = num_labels
        self.vit = ViTForImageClassification(config)

    def forward(self, x):
        """Return ``(logits, attentions)``.

        ``attentions`` is only populated when the config enables ``output_attentions``.
        """
        outputs = self.vit(x)
        return outputs.logits, outputs.attentions

In [None]:
# Selecting GPU. GPU[2] means "whatever is available": the default CUDA device, else CPU.
GPU = {0: torch.device('cuda:0'),
       1: torch.device('cuda:1'),
       2: torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      }

## Single-GPU training
# Prefer the second GPU when it exists. Otherwise fall back to GPU[2], so the notebook
# also runs on single-GPU or CPU-only machines instead of failing on 'cuda:1'.
device = GPU[1] if torch.cuda.device_count() > 1 else GPU[2]
model = ViT(num_labels=3).to(device)

## Multi-GPU training
# device = GPU[2]
# model = ViT(num_labels=3)
# model = nn.DataParallel(model)
# model.to(device);

# Resizing and normalisation are disabled; the extractor is only used to batch samples.
# NOTE(review): newer transformers releases replace ViTFeatureExtractor with
# ViTImageProcessor, which also rescales pixels by 1/255 by default (do_rescale).
# Confirm the behaviour against the installed version.
feature_extractor = ViTFeatureExtractor(size=image_size,
                                        do_resize=False,
                                        do_normalize=False)


optimizer = Adam(model.parameters(), lr=5e-6, weight_decay=1e-3)

# Inverse class-frequency weights for the training loss (the frequencies are hard-coded).
# The training cell assigns them to `criterion.weight` for training steps and clears
# them for validation.
class_0_freq = 140
class_1_freq = 160
class_2_freq = 160
weight = torch.tensor([1/class_0_freq, 1/class_1_freq, 1/class_2_freq]).to(device)
criterion = nn.CrossEntropyLoss()

# NOTE(review): torchmetrics >= 0.11 requires `task=` (e.g. Accuracy(task="multiclass",
# num_classes=3)); the bare constructor only works on older releases.
accuracy = Accuracy()
writer = SummaryWriter()
# Multiplies the learning rate by 0.999 on every scheduler.step() (once per epoch below).
scheduler = ExponentialLR(optimizer, gamma=0.999)

In [None]:
# Training loop with periodic validation, best-checkpoint tracking and early stopping.
#
# Every `eval_every` training steps, and on the last step of each epoch, the running
# training metrics are logged to TensorBoard and the model is evaluated on the whole
# validation set. Two state dicts are kept in memory: the one with the lowest validation
# loss (`best_model_loss`) and the one with the highest validation accuracy
# (`best_model_acc`). `patience` counts epochs since the last improvement; training
# stops once it reaches `early_stop`.
epochs = 100
eval_every = 7
early_stop = 10

train_accs = []
train_losses = []
best_loss = float('inf')   # the first validation is always an improvement, so
best_acc = float('-inf')   # best_model_loss / best_model_acc are always set
best_model_loss = None
best_model_acc = None
saved = False
patience = 0
num_train_steps = len(train_dataloader)


def _to_pixel_values(batch_x):
    """Convert a batch from the dataloader into the model input tensor.

    Each sample is squeezed and the list is passed through ``feature_extractor``, as
    before. The batch is split by its *actual* length rather than a configured batch
    size: a smaller final batch (drop_last=False) previously made ``np.split`` raise.
    """
    samples = [np.squeeze(sample) for sample in np.asarray(batch_x)]
    return torch.tensor(np.stack(feature_extractor(samples)['pixel_values'], axis=0))


def _validate():
    """Evaluate ``model`` on ``valid_dataloader`` using the unweighted loss.

    Returns (mean batch loss, mean batch accuracy). Afterwards the model is back in
    train mode and ``criterion.weight`` is restored to the training class weights.
    """
    losses, accs = [], []
    model.eval()
    criterion.weight = None  # report the plain (unweighted) validation loss
    try:
        with torch.no_grad():
            for vx, vy in valid_dataloader:
                vx, vy = _to_pixel_values(vx).to(device), vy.to(device)
                vlogits, _ = model(vx)
                losses.append(criterion(vlogits, vy).item())
                # torchmetrics expects (preds, target)
                accs.append(accuracy(vlogits.argmax(1).cpu(), vy.cpu()).item())
    finally:
        criterion.weight = weight
        model.train()
    return sum(losses) / len(losses), sum(accs) / len(accs)


criterion.weight = weight  # class-frequency weights for the training loss

for epoch in range(epochs):
    print(Fore.YELLOW + f"Epoch: {(epoch+1):02}/{epochs}")
    for step, (x, y) in enumerate(train_dataloader):
        x, y = _to_pixel_values(x).to(device), y.to(device)
        logits, _ = model(x)
        loss = criterion(logits, y)
        preds = logits.argmax(1)
        acc = accuracy(preds.cpu(), y.cpu())  # torchmetrics expects (preds, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_accs.append(acc.item())

        # `step` is 0-based, so the final step of the epoch is num_train_steps - 1
        # (the old `step == len(train_dataloader)` check could never be true).
        if step % eval_every == 0 or step == num_train_steps - 1:
            global_step = epoch * num_train_steps + step
            train_loss = sum(train_losses) / len(train_losses)
            train_acc = sum(train_accs) / len(train_accs)
            writer.add_scalar('train_loss', train_loss, global_step)
            writer.add_scalar('train_acc', train_acc, global_step)
            train_losses.clear()
            train_accs.clear()

            valid_loss, valid_acc = _validate()
            writer.add_scalar('valid_loss', valid_loss, global_step)
            writer.add_scalar('valid_acc', valid_acc, global_step)

            if valid_loss < best_loss:
                best_loss = valid_loss
                best_model_loss = deepcopy(model.state_dict())
                saved = True

            if valid_acc > best_acc:
                best_acc = valid_acc
                best_model_acc = deepcopy(model.state_dict())
                saved = True

            # Green = a new best checkpoint was kept, red = no improvement.
            color = Fore.GREEN if saved else Fore.RED
            print(color + f"Training Loss(Accuracy): {train_loss:.2f}({train_acc:.2f}), Validation Loss(Accuracy): {valid_loss:.2f}({valid_acc:.2f})")
            if saved:
                saved = False
                patience = 0

    scheduler.step()

    print(Fore.YELLOW + "=" * 74)

    if patience >= early_stop:
        print("Early stop activated!")
        break

    patience += 1

writer.flush()  # make sure the last scalars reach TensorBoard

# Save and Load Model

In [None]:
# Persist both in-memory best checkpoints: lowest validation loss, then highest
# validation accuracy.
for best_state, checkpoint_name in ((best_model_loss, "ViT_3D_loss.pt"),
                                    (best_model_acc, "ViT_3D_acc.pt")):
    save_model(best_state, "Best models/", checkpoint_name)

In [None]:
# Reload the best-accuracy checkpoint. map_location remaps tensors saved from another
# GPU (e.g. cuda:1) onto the device selected for this session, including CPU.
# Assumes save_model stored a plain state dict via torch.save (TODO confirm).
model.load_state_dict(torch.load("Best models/ViT_3D_acc.pt", map_location=device))

# Evaluation

In [None]:
# Deterministic, augmentation-free view of the training split for evaluation.
# It is built from a *copy* of `hparams` and bound to new names. Previously this cell
# overwrote `train_ds` / `train_dataloader` / `hparams`, so re-running the training
# cell afterwards silently trained without augmentation or shuffling.
# NOTE(review): this loader is not evaluated below; pass it to `predict` to get
# training-set metrics.
train_eval_transforms = Compose([])

train_eval_ds = ADNI("../Data/Training/", transforms=train_eval_transforms)

train_eval_hparams = {**hparams,
                      'train_ds': train_eval_ds,
                      'train_shuffle': False,
                      }

train_eval_dataloader = ADNILoader(**train_eval_hparams).train_dataloader()

In [None]:
def predict(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect labels and argmax predictions.

    Args:
        model: module whose forward returns ``(logits, attentions)`` (e.g. ``ViT``).
        dataloader: iterable of ``(x, y)`` tensor batches.
        device: device the model lives on; each input batch is moved there.

    Returns:
        ``(y_true, y_pred)`` numpy arrays, concatenated over all batches along axis 0
        (two empty arrays if the dataloader yields nothing).

    The model is left in eval mode.
    """
    y_true = []
    y_pred = []

    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            # Split by the actual batch length. With drop_last=False the final batch
            # can be smaller than dataloader.batch_size, which made np.split raise.
            samples = [np.squeeze(sample) for sample in np.asarray(x)]
            x = torch.tensor(np.stack(feature_extractor(samples)['pixel_values'], axis=0))
            logits, _ = model(x.to(device))

            y_pred.append(logits.argmax(1).cpu().numpy())
            y_true.append(y.cpu().numpy())

    if not y_pred:  # np.concatenate raises on an empty list
        return np.array([]), np.array([])

    return np.concatenate(y_true, axis=0), np.concatenate(y_pred, axis=0)

# Classification reports on the validation split, then the test split.
for eval_dataloader in (valid_dataloader, test_dataloader):
    y_true, y_pred = predict(model, eval_dataloader, device)
    custom_classification_report(y_true, y_pred)
    sklearn_classification_report(y_true, y_pred)
    print('#'*53)