In [None]:
# Autoreload modules: with mode 2, IPython reloads imported modules before
# executing each cell, so edits to the local `dataloader` / `utils` packages
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Make the project root (the parent of the notebook's working directory)
# importable, so the local `dataloader` and `utils` packages imported below resolve.
import sys, os
project_root = os.path.dirname(os.path.realpath(''))
# Prepend (rather than append) so these local packages take precedence over any
# installed distribution with the same name (e.g. a PyPI package called `utils`),
# and skip the insert when the cell is re-run so sys.path does not keep growing.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

In [None]:
import numpy as np

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import Accuracy

from torchvision.transforms import Compose, GaussianBlur, RandomRotation, RandomChoice, RandomApply, RandomAffine, Resize
from dataloader.transforms import GaussianNoise

from copy import deepcopy

from colorama import Fore

from matplotlib import pyplot as plt

from dataloader.dataset import ADNI3Channels
from dataloader.dataloader import ADNILoader

from utils.utils import count_parameters, save_model
from utils.report import cnn_report

# Dataset and Dataloader Setup

In [None]:
# Every split is resized to the same fixed resolution.
image_size = (384, 384)
resize = Resize(size=image_size)

# NOTE(review): these two augmentations are constructed but not used by any of
# the pipelines defined below.
gaussian_blur = GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2))
gaussian_noise = GaussianNoise(mean=0, std=0.001)

# Light geometric jitter for training: with probability 0.2, apply exactly one
# of a small rotation (up to 3 degrees) or a small translation (up to 1% per axis).
random_rotation = RandomRotation(degrees=3)
random_translate = RandomAffine(degrees=0, translate=(0.01, 0.01))
random_choice = RandomChoice([random_rotation, random_translate])
random_transforms = RandomApply([random_choice], p=0.2)

train_transforms = Compose([resize, random_transforms])
# Validation and test are resize-only (two separate Compose objects).
valid_transforms, test_transforms = (Compose([resize]) for _ in range(2))

In [None]:
# One dataset per split directory, each with its own transform pipeline
# (built in order: training, validation, test).
train_ds, valid_ds, test_ds = (
    ADNI3Channels(split_dir, transforms=split_transforms)
    for split_dir, split_transforms in (
        ("../Data/Training/", train_transforms),
        ("../Data/Validation/", valid_transforms),
        ("../Data/Test/", test_transforms),
    )
)

In [None]:
# Inspect one training sample (its shape and label), the split sizes, and each
# of its three channels as a separate image. Because train_ds uses
# train_transforms, the sample may have a random rotation/translation applied.
idx = 0
image, label = train_ds[idx]

print("Image shape:", image.shape)
print("Label:", label.item())

print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(valid_ds))
print("Number of test samples:", len(test_ds), "\n")

fig, axes = plt.subplots(ncols=3, figsize=(6, 2), dpi=300)
# `image` is indexed channel-first here: image[channel, :, :].
for channel, ax in enumerate(axes):
    ax.imshow(image[channel, :, :])
    ax.set_title(f"Channel {channel}", fontsize=6)
    ax.axis("off")
plt.show()
    # print(image[i, :, :].min(), image[i, :, :].max())

In [None]:
# Class index <-> diagnosis name. label2id is derived by inverting id2label so
# the two mappings cannot drift apart.
id2label = {0: "CN", 1: "MCI", 2: "AD"}
label2id = {name: class_id for class_id, name in id2label.items()}

# Diagnosis name of the sample inspected above.
print(id2label[label.item()])

In [None]:
train_batch_size = 5
valid_batch_size = 2
test_batch_size = 2

# Cap worker processes at the number of available CPUs: PyTorch's DataLoader
# warns (and the machine oversubscribes) when num_workers exceeds that.
# os.cpu_count() may return None, hence the fallback to 1.
num_workers = min(20, os.cpu_count() or 1)

hparams = {'train_ds': train_ds,
           'valid_ds': valid_ds,
           'test_ds': test_ds,
           'train_batch_size': train_batch_size,
           'valid_batch_size': valid_batch_size,
           'test_batch_size': test_batch_size,
           'num_workers': num_workers,
          }

# Build all three loaders from a single ADNILoader instead of three identical ones.
# NOTE(review): assumes ADNILoader's constructor has no per-instance side effects
# that the separate instances relied on — confirm in dataloader/dataloader.py.
data_module = ADNILoader(**hparams)
train_dataloader = data_module.train_dataloader()
valid_dataloader = data_module.validation_dataloader()
test_dataloader = data_module.test_dataloader()

# Peek at one training batch to check input and label tensor shapes.
batch = next(iter(train_dataloader))
print(batch[0].shape)
print(batch[1].shape)

# Model Development

In [None]:
class CNN(nn.Module):
    """Small 2D CNN classifier for 3-channel images.

    Two conv-conv-pool stages (3->16->32 and 32->64->128 channels, unpadded 3x3
    convolutions with batch norm and ReLU, 2D dropout, 3x3 max-pooling) feed a
    three-layer fully connected head with batch norm and dropout.

    Args:
        num_labels: number of output classes (size of the final logits layer).
        input_size: (height, width) of the input images. It is used only to size
            the first fully connected layer; the default (384, 384) reproduces
            the previously hard-coded 204800 flattened features.

    Raises:
        ValueError: if `input_size` is too small to survive both stages.
    """

    def __init__(self, num_labels=3, input_size=(384, 384)):
        super().__init__()
        self.num_labels = num_labels
        self.input_size = tuple(input_size)
        flat_features = self._flat_features(self.input_size)

        # Layer order is unchanged, so state_dict keys (model.0.weight, ...) stay
        # compatible with checkpoints saved from earlier versions of this class.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3)),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3)),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.Dropout2d(0.1),
            nn.MaxPool2d(kernel_size=(3, 3)),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.Dropout2d(0.1),
            nn.MaxPool2d(kernel_size=(3, 3)),

            nn.Flatten(),

            nn.Linear(flat_features, 1024),
            nn.BatchNorm1d(num_features=1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            nn.Dropout(0.5),
            # Previously hard-coded to 3, silently ignoring num_labels.
            nn.Linear(256, num_labels),
        )

    @staticmethod
    def _flat_features(input_size):
        """Return the number of features nn.Flatten produces for `input_size`."""
        height, width = input_size
        for _ in range(2):  # two conv-conv-pool stages
            # Each unpadded 3x3 conv shrinks a spatial dim by 2 (two convs per
            # stage); MaxPool2d(kernel_size=3) defaults to stride 3 and floors.
            height, width = (height - 4) // 3, (width - 4) // 3
        if height < 1 or width < 1:
            raise ValueError(f"input_size {input_size} is too small for this network")
        return 128 * height * width

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_labels)."""
        return self.model(x)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keep the number of classes in one place: the label map defined earlier.
num_labels = len(id2label)
model = CNN(num_labels).to(device)

optimizer = Adam(model.parameters(), lr=1e-4)

# Weight each class by the inverse of its count, so the class with fewer samples
# contributes more per-sample loss. Index order follows id2label (0=CN, 1=MCI, 2=AD).
# NOTE(review): these counts are hard-coded, not computed from train_ds —
# presumably the training-split class counts; confirm before changing the data.
class_0_freq = 140
class_1_freq = 160
class_2_freq = 160
weight = torch.tensor([1/class_0_freq, 1/class_1_freq, 1/class_2_freq]).to(device)
criterion = nn.CrossEntropyLoss(weight=weight)

# torchmetrics >= 0.11 requires an explicit `task`; "multiclass" matches the
# integer class predictions (argmax of the logits) fed to it during training.
accuracy = Accuracy(task="multiclass", num_classes=num_labels)
writer = SummaryWriter()
# Multiplies the learning rate by 0.999 on every scheduler.step() call.
scheduler = ExponentialLR(optimizer, gamma=0.999)

In [None]:
def evaluate(model, dataloader, criterion, metric, device):
    """Return the mean per-batch (loss, accuracy) of `model` over `dataloader`.

    Puts the model in eval mode and disables gradient tracking; the caller is
    responsible for switching the model back to train mode afterwards.
    """
    losses, accs = [], []
    model.eval()
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            losses.append(criterion(logits, targets).item())
            # torchmetrics metrics are called as metric(preds, target).
            accs.append(metric(logits.argmax(1).cpu(), targets.cpu()).item())
    return sum(losses) / len(losses), sum(accs) / len(accs)


# Train for `epochs` epochs. Every `log_every` steps (and on the last step of each
# epoch) the running training metrics are logged to TensorBoard and the whole
# validation set is evaluated; in-memory copies of the weights with the lowest
# validation loss and the highest validation accuracy are kept.
epochs = 200
log_every = 20
train_accs = []
train_losses = []
# Start from +inf (was 100) so the first evaluation is always recorded and
# best_model_loss is defined before the save cell runs.
best_loss = float("inf")
best_acc = 0
best_model_loss = None
best_model_acc = None

steps_per_epoch = len(train_dataloader)

for epoch in range(epochs):
    print(Fore.YELLOW + f"Epoch: {(epoch+1):02}/{epochs}")
    model.train()
    for step, (x, y) in enumerate(train_dataloader):
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        preds = logits.argmax(1)
        # torchmetrics argument order is (preds, target).
        acc = accuracy(preds.cpu(), y.cpu())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_accs.append(acc.item())

        # `step` is 0-based, so the last batch is steps_per_epoch - 1; the former
        # check `step == len(train_dataloader)` could never be true.
        if step % log_every == 0 or step == steps_per_epoch - 1:
            global_step = epoch * steps_per_epoch + step
            train_loss = sum(train_losses) / len(train_losses)
            train_acc = sum(train_accs) / len(train_accs)
            writer.add_scalar('train_loss', train_loss, global_step)
            writer.add_scalar('train_acc', train_acc, global_step)
            train_losses.clear()
            train_accs.clear()

            valid_loss, valid_acc = evaluate(model, valid_dataloader, criterion, accuracy, device)
            writer.add_scalar('valid_loss', valid_loss, global_step)
            writer.add_scalar('valid_acc', valid_acc, global_step)

            # Snapshot the weights whenever either validation metric improves.
            improved = False
            if valid_loss < best_loss:
                best_loss = valid_loss
                best_model_loss = deepcopy(model.state_dict())
                improved = True
            if valid_acc > best_acc:
                best_acc = valid_acc
                best_model_acc = deepcopy(model.state_dict())
                improved = True

            # Green when this evaluation improved a best metric, red otherwise.
            color = Fore.GREEN if improved else Fore.RED
            print(color + f"Training Loss(Accuracy): {train_loss:.2f}({train_acc:.2f}), Validation Loss(Accuracy): {valid_loss:.2f}({valid_acc:.2f})")

            model.train()

    scheduler.step()

    print(Fore.YELLOW + "=" * 74)

# Make sure all logged scalars reach disk.
writer.flush()

# Model Save and Load

In [None]:
# Persist the weights that achieved the lowest validation loss during training.
best_models_dir = "Best models/"
best_loss_checkpoint = "best_model_2D_loss.pt"
save_model(best_model_loss, best_models_dir, best_loss_checkpoint)

In [None]:
# Restore the best-validation-loss weights into `model`. map_location remaps the
# stored tensors onto the current device, so a checkpoint saved from GPU tensors
# also loads on a CPU-only machine.
model.load_state_dict(torch.load("Best models/best_model_2D_loss.pt", map_location=device))

# Evaluation

In [None]:
# Evaluate the currently loaded weights on the validation split.
# NOTE(review): cnn_report comes from utils.report; whether it prints a report
# or returns one is not visible from this notebook.
cnn_report(model, valid_dataloader, device)

In [None]:
# Evaluate the currently loaded weights on the held-out test split.
# NOTE(review): cnn_report comes from utils.report; whether it prints a report
# or returns one is not visible from this notebook.
cnn_report(model, test_dataloader, device)