In [None]:
# Autoreload modules
%load_ext autoreload
%autoreload 2

# Dataset and Dataloader Setup

In [None]:
from matplotlib import pyplot as plt

from dataloader.dataset import ADNI
from dataloader.dataloader import ADNILoader

image_size = 384

# The three splits share the resize settings; only the training split is augmented.
split_kwargs = dict(do_resize=True, image_size=image_size)
train_ds = ADNI("../Data/Training/", do_augmentation=True, **split_kwargs)
valid_ds = ADNI("../Data/Validation/", do_augmentation=False, **split_kwargs)
test_ds = ADNI("../Data/Test/", do_augmentation=False, **split_kwargs)

# Fetch the sample once: indexing the dataset twice loads the item twice, so
# `image` and `label` would come from two separate __getitem__ calls.
idx = 0
sample = train_ds[idx]
image, label = sample[0], sample[1]

print("Image shape:", image.shape)
print("Label:", label.item())

print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

# Show the three slices taken along the first dimension of the sample.
fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
for i, ax in enumerate(axes):
    ax.imshow(image[i, :, :])
    ax.axis("off");
    ax.set_title(f"Along axis {i}")

In [None]:
# Class index <-> class name lookup tables.
id2label = {0: "CN", 1: "MCI", 2: "AD"}
# The reverse map is derived from the forward one so the two can never disagree.
label2id = {name: class_id for class_id, name in id2label.items()}

# Human-readable class of the sample inspected in the previous cell.
print(id2label[label.item()])

In [None]:
# Per-split batch sizes.
train_batch_size, valid_batch_size, test_batch_size = 10, 5, 5

# Keyword arguments shared by every ADNILoader constructed below.
hparams = dict(
    train_ds=train_ds,
    valid_ds=valid_ds,
    test_ds=test_ds,
    train_batch_size=train_batch_size,
    valid_batch_size=valid_batch_size,
    test_batch_size=test_batch_size,
    num_workers=20,
)

# One loader object is built per split, each from the same hyper-parameters.
train_dataloader = ADNILoader(**hparams).train_dataloader()
valid_dataloader = ADNILoader(**hparams).validation_dataloader()
test_dataloader = ADNILoader(**hparams).test_dataloader()

# Pull a single training batch and show the shapes of its two components.
batch = next(iter(train_dataloader))
for component in (batch[0], batch[1]):
    print(component.shape)

In [None]:
# One panel per sample of the training batch fetched in the previous cell.
fig, axes = plt.subplots(figsize=(train_batch_size * 2, 2), ncols=train_batch_size, dpi=300)
for i in range(train_batch_size):
    # Loop-local names: rebinding the notebook-level `image` / `label` here
    # would turn `label` into a str and break a re-run of the earlier cell
    # that calls `label.item()`.
    sample_image = batch[0][i]
    sample_label = id2label[batch[1][i].item()]
    # Move the first dimension last, then display index 0 of that dimension.
    axes[i].imshow(sample_image.permute(1, 2, 0)[:, :, 0])
    axes[i].set_title(f"Label: {sample_label}")
    axes[i].axis("off");

# Model Setup and Training

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW, RMSprop
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import Accuracy

from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTConfig

from copy import deepcopy

In [None]:
# Architecture settings for the ViT classifier built in the next cells.
# `output_attentions=True` makes the model return its attention maps, which
# the wrapper's forward() hands back alongside the logits.
vit_config = ViTConfig(
    **{
        "num_hidden_layers": 7,
        "hidden_size": 600,
        "intermediate_size": 256,
        "num_attention_heads": 12,
        "image_size": image_size,
        "num_labels": 3,
        "output_attentions": True,
        "hidden_dropout_prob": 0.1,
    }
)

In [None]:
class ViT(nn.Module):
    """Thin wrapper around `ViTForImageClassification` built from `vit_config`.

    Args:
        num_labels: Number of output classes. Applied to a private copy of the
            notebook-level `vit_config`, so the shared config is never mutated.
            With the default of 3 the resulting model matches `vit_config`
            as defined in this notebook.
    """

    def __init__(self, num_labels=3):
        super().__init__()
        # Previously `num_labels` was accepted but ignored; honour it on a copy
        # so other users of `vit_config` are unaffected.
        config = deepcopy(vit_config)
        config.num_labels = num_labels
        # Attribute name `vit` is kept so existing checkpoints' state-dict keys still match.
        self.vit = ViTForImageClassification(config)

    def forward(self, pixel_values):
        """Run the classifier and return `(logits, attentions)`.

        Args:
            pixel_values: Model input, forwarded unchanged to the wrapped model.

        Returns:
            A tuple of the wrapped model's `logits` and `attentions` outputs.
        """
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits, outputs.attentions

In [None]:
# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Three-class ViT, moved to the selected device.
model = ViT(num_labels=3)
model.to(device);

# Resizing is disabled here; the datasets are already constructed with do_resize=True.
feature_extractor = ViTFeatureExtractor(do_resize=False, size=image_size)

# Optimisation: AdamW with a per-epoch exponential learning-rate decay.
optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=3e-2)
scheduler = ExponentialLR(optimizer, gamma=0.999)
criterion = nn.CrossEntropyLoss()

# Metric and TensorBoard logging.
# NOTE(review): newer torchmetrics releases may require a `task` argument here — confirm against the installed version.
accuracy = Accuracy()
writer = SummaryWriter()

In [None]:
epochs = 200
eval_every = 10  # validate and log to TensorBoard every `eval_every` training steps

train_accs = []
train_losses = []
# Infinite sentinels guarantee that the first evaluation records both
# checkpoints, so `best_model_loss` / `best_model_acc` exist for later cells.
best_loss = float("inf")
best_acc = float("-inf")


def _prepare_batch(images):
    """Run a batch through `feature_extractor` and return a stacked tensor.

    The batch is split along its first dimension using its *actual* length, so
    a final batch smaller than the configured batch size is handled as well
    (`np.split(..., batch_size)` raises when the division is not exact).
    """
    samples = [np.squeeze(sample) for sample in np.array(images)]
    pixel_values = feature_extractor(samples)['pixel_values']
    return torch.tensor(np.stack(pixel_values, axis=0))


def _evaluate(dataloader):
    """Return `(mean loss, mean accuracy)` over the batches of `dataloader`.

    Runs in eval mode without gradient tracking and puts the model back into
    train mode before returning. Uses its own local names so the caller's
    training tensors are not overwritten.
    """
    batch_losses, batch_accs = [], []
    model.eval()
    with torch.no_grad():
        for images, targets in dataloader:
            inputs, targets = _prepare_batch(images).to(device), targets.to(device)
            eval_logits, _ = model(inputs)
            batch_losses.append(criterion(eval_logits, targets).item())
            # Metric arguments are passed as (preds, target).
            batch_accs.append(accuracy(eval_logits.argmax(1).cpu(), targets.cpu()).item())
    model.train()
    return sum(batch_losses) / len(batch_losses), sum(batch_accs) / len(batch_accs)


steps_per_epoch = len(train_dataloader)
# Make sure dropout is active even if a previous run was interrupted mid-evaluation.
model.train()

for epoch in range(epochs):
    print(f"Epoch: {(epoch+1):02}/{epochs}")
    for step, (x, y) in enumerate(train_dataloader):
        x, y = _prepare_batch(x).to(device), y.to(device)
        logits, _ = model(x)
        loss = criterion(logits, y)
        preds = logits.argmax(1)
        acc = accuracy(preds.cpu(), y.cpu())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_accs.append(acc.item())

        # `enumerate` stops at steps_per_epoch - 1, so that is the last step.
        is_last_step = step == steps_per_epoch - 1
        if step % eval_every != 0 and not is_last_step:
            continue

        global_step = epoch * steps_per_epoch + step

        # Running training metrics since the previous logging point.
        train_loss = sum(train_losses) / len(train_losses)
        train_acc = sum(train_accs) / len(train_accs)
        writer.add_scalar('train_loss', train_loss, global_step)
        writer.add_scalar('train_acc', train_acc, global_step)
        train_losses.clear()
        train_accs.clear()

        valid_loss, valid_acc = _evaluate(valid_dataloader)
        writer.add_scalar('valid_loss', valid_loss, global_step)
        writer.add_scalar('valid_acc', valid_acc, global_step)

        # Keep in-memory snapshots of the best weights by loss and by accuracy.
        improved_loss = best_loss > valid_loss
        improved_acc = best_acc < valid_acc
        if improved_loss:
            best_loss = valid_loss
            best_model_loss = deepcopy(model.state_dict())
        if improved_acc:
            best_acc = valid_acc
            best_model_acc = deepcopy(model.state_dict())

        if improved_loss or improved_acc:
            print(f"Model saved: Training Loss(Accuracy): {train_loss:.2f}({train_acc:.2f}), Validation Loss(Accuracy): {valid_loss:.2f}({valid_acc:.2f})")

    # Learning-rate decay is applied once per epoch.
    scheduler.step()

    print("=" * 87)

# Save and Load Model

In [None]:
from utils.utils import save_model

# Write each in-memory snapshot to its own file. The accuracy file previously
# received the best-loss weights (copy-paste slip), so the best-accuracy
# checkpoint was never saved.
save_model(best_model_loss, "Best model/", "best_model_2_loss.pt")
save_model(best_model_acc, "Best model/", "best_model_2_acc.pt")

In [None]:
# NOTE(review): this loads a "best_model_1_*" file while the save cell writes
# "best_model_2_*" — confirm that restoring the earlier run is intended.
# `map_location` lets a checkpoint saved from a GPU be loaded on whichever
# device this session selected (including CPU-only machines).
model.load_state_dict(torch.load("Best model/best_model_1_loss.pt", map_location=device))

# Number of Trainable Parameters

In [None]:
from utils.utils import count_parameters

# Report the value computed by the project helper for the current model.
param_count = count_parameters(model)
print(param_count)

# Evaluation

In [None]:
from utils.report import report
# Run the project's report helper on the current model using the validation loader.
# NOTE(review): `test_dataloader` is built earlier but never used, and the
# validation split also drove checkpoint selection during training — confirm
# whether the held-out test split should be reported here instead.
report(model, feature_extractor, valid_dataloader, device)

# Attention Map Visualization

In [None]:
from utils.visualize_attention import visualize_attention

image = train_ds[0][0]

# Preprocess the sample the same way the training loop does, so the model
# sees inputs in the form it was trained on.
pixel_values = torch.tensor(
    np.stack(feature_extractor([np.squeeze(np.array(image))])['pixel_values'], axis=0)
)

# Inference only: eval mode disables dropout (the training loop leaves the
# model in train mode) and no_grad avoids building an autograd graph.
model.eval()
with torch.no_grad():
    _, att_mat = model(pixel_values.to(device))

visualize_attention(image, att_mat, device)