# Voxel Based Networks

As a part of the ablation study, we train 3 voxel based models to compare their performances to the patch-based architectures:
- A voxel based model trained on MRI scans
- A voxel based model trained on PET scans
- A voxel based multimodal model trained on the combined feature maps of above 2 models

3D ResNet architecture code sourced from https://github.com/kenshohara/3D-ResNets-PyTorch

# Importing Libraries

In [1]:
# importing the libraries
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
import os
import gc
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import precision_score, recall_score, accuracy_score
from sklearn.preprocessing import StandardScaler


import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import *
import nibabel as nb
import torchio as tio

from models import ResNetV2




## Read in images

In [2]:
# Subject IDs who are progressive normal cognition
PNC = pd.read_pickle('PNC.pkl')

# Subject IDs who are stable normal cognition
SNC = pd.read_pickle('SNC.pkl')

In [3]:
# Create datasets

def read_image_data(input_path):
    X_train = []
    y_train = []
    ids = []
    for filename in sorted(os.listdir(input_path)):
        file = os.path.join(input_path, filename)
        img_file = nb.load(file)
        img = img_file.get_fdata()
        X_train.append(img)
        ids.append(filename)
        # Progressive normal cognition target class 1
        if filename[0:8] in np.array(PNC):
            y_train.append(1)
        # Stable normal cognition class 0
        else:
            y_train.append(0)

    X_train = np.array(X_train)
    y_train = np.array(y_train)

    # Reshape X_train to include channel dimension
    X_train = X_train.reshape(X_train.shape + (1,))
    return X_train, y_train, ids

X_train_MRI, y_train_MRI, ids_MRI = read_image_data("E:/Work/Processed_MRI/2.MNI_Registration")
X_train_PET, y_train_PET, ids_PET = read_image_data("E:/Work/Processed_PIB/4.MNI_Registered")


## Creating training and test sets for images

In [4]:
# 0.6 - 0.2 - 0.2 train-val-test split
X_train_MRI, X_test_MRI, X_train_PET, X_test_PET, ids_train, ids_test, y_train, y_test = train_test_split(X_train_MRI, 
                                                                                     X_train_PET, ids_MRI, 
                                                                                     y_train_MRI, test_size=0.2, 
                                                                                     random_state=101, stratify=y_train_MRI)

X_train_MRI, X_val_MRI, X_train_PET, X_val_PET, ids_train, ids_val, y_train, y_val = train_test_split(X_train_MRI, X_train_PET, 
                                                                                      ids_train, y_train, 
                                                                                      test_size=0.25, random_state=101,
                                                                                      stratify=y_train)

In [5]:
# Convert datasets to tensor format, with channel first
train_x_MRI = torch.from_numpy(X_train_MRI).float().permute(0,4,1,2,3)
train_x_PET = torch.from_numpy(X_train_PET).float().permute(0,4,1,2,3)
train_y = torch.from_numpy(y_train).float()

val_x_MRI = torch.from_numpy(X_val_MRI).float().permute(0,4,1,2,3)
val_x_PET = torch.from_numpy(X_val_PET).float().permute(0,4,1,2,3)
val_y = torch.from_numpy(y_val).float()

test_x_MRI = torch.from_numpy(X_test_MRI).float().permute(0,4,1,2,3)
test_x_PET = torch.from_numpy(X_test_PET).float().permute(0,4,1,2,3)
test_y = torch.from_numpy(y_test).float()

## Perform data augmentation to increase training set size

In [14]:
def perform_augmentation(dataset, seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Define transformations
    training_transform = tio.Compose([
        tio.RandomAffine(),
        tio.RandomElasticDeformation(),
        tio.RandomFlip()
    ])
    augmented_dataset = torch.clone(dataset) 
    for i in range(len(augmented_dataset)):
        augmented_dataset[i] = training_transform(augmented_dataset[i])
    return augmented_dataset

orig_train_x_MRI = torch.clone(train_x_MRI)
orig_train_x_PET = torch.clone(train_x_PET)
orig_train_y = torch.clone(train_y)

for seed in [1,101,42]:
    # Apply transformations and create augmented training set
    augmented_train_MRI = perform_augmentation(orig_train_x_MRI, seed)
    augmented_train_PET = perform_augmentation(orig_train_x_PET, seed)
    
    
    # Concatenate training and augmented training datasets
    train_x_MRI = torch.cat((train_x_MRI, augmented_train_MRI), 0)
    train_x_PET = torch.cat((train_x_PET, augmented_train_PET), 0)
    train_y = torch.cat((train_y, orig_train_y), 0)



In [6]:
# Load preaugmented data (for reruns)
train_x_MRI = torch.load("train_x_MRI.pkl")
train_x_PET = torch.load("train_x_PET.pkl")
train_y = torch.load("train_y.pkl")

In [7]:
# Create Pytorch train and test sets
train_MRI = torch.utils.data.TensorDataset(train_x_MRI,train_y)
train_PET = torch.utils.data.TensorDataset(train_x_PET,train_y)

val_MRI = torch.utils.data.TensorDataset(val_x_MRI,val_y)
val_PET = torch.utils.data.TensorDataset(val_x_PET,val_y)

test_MRI = torch.utils.data.TensorDataset(test_x_MRI,test_y)
test_PET = torch.utils.data.TensorDataset(test_x_PET,test_y)


# Create data loader
batch_size = 16
train_loader_MRI = torch.utils.data.DataLoader(train_MRI, batch_size = batch_size, shuffle = True)
train_loader_PET = torch.utils.data.DataLoader(train_PET, batch_size = batch_size, shuffle = True)

val_loader_MRI = torch.utils.data.DataLoader(val_MRI, batch_size = batch_size, shuffle = False)
val_loader_PET= torch.utils.data.DataLoader(val_PET, batch_size = batch_size, shuffle = False)

test_loader_MRI = torch.utils.data.DataLoader(test_MRI, batch_size = batch_size, shuffle = False)
test_loader_PET = torch.utils.data.DataLoader(test_PET, batch_size = batch_size, shuffle = False)

In [8]:
# Define device being used for model training
if torch.cuda.is_available(): 
    dev = "cuda:0" 
else: 
    dev = "cpu" 
device = torch.device(dev) 

# Train  Models

In [10]:
def train_model(model, train_loader, optimiser, error, scheduler, feature_map = False):
    
    total_train_loss = 0
    for image, labels in train_loader:
        image = image.to(device)
        labels = labels.to(device)
        
        # Clear gradients
        optimiser.zero_grad()
        
        # Forward propagation
        # feature_map indicates whether the feature map prior final layer is also included in the outputs of the model
        if feature_map:
            outputs = model(image)[1]
        else:
            outputs = model(image)
            
        # Calculate loss
        loss = error(outputs.flatten(), labels)
        total_train_loss += loss.item()
        
        # Calculating gradients
        loss.backward()

        # Update parameters
        optimiser.step()
        scheduler.step()
        
        # Clear GPU Cache
        torch.cuda.empty_cache()
        gc.collect()
        
        # Calculate accuracy
        preds = outputs.flatten().round()
    print("Average Training Loss", total_train_loss/len(train_loader.dataset))

def validate_model(model, val_loader, error, feature_map = False):
    correct_predictions_val = 0
    total_val_loss = 0
    with torch.no_grad():
        for image, labels in val_loader:
            image = image.to(device)
            labels = labels.to(device)

            # Forward propagation
            if feature_map:
                pred = model(image)[1]
            else:
                pred = model(image)
            
            # Calculate loss
            loss = error(pred.flatten(), labels)
            total_val_loss += loss.item()

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Calculate accuracy
            preds = pred.flatten().round()
            correct_predictions_val += torch.sum(preds == labels).item()
        print("Validation Accuracy:", correct_predictions_val/len(val_loader.dataset))
        print("Average Validation Loss:", total_val_loss/len(val_loader.dataset))
    return total_val_loss/len(val_loader.dataset)
        
def evaluate_model(model, test_loader, error, feature_map = False):
    correct_predictions_test = 0
    preds = []
    labels = []
    with torch.no_grad():
        for test,label in test_loader:
            test = test.to(device)
            label = label.to(device)
            labels.append(label.cpu())
            
            # Forward propagation
            if feature_map:
                pred = model(test)[1]
            else:
                pred = model(test)

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Calculate accuracy
            pred = pred.flatten().round()
            preds.append(pred.cpu())
            correct_predictions_test += torch.sum(pred == label).item()
    print("Test Accuracy:", correct_predictions_test/len(test_loader.dataset))    
    print("Recall:", recall_score(tborch.cat(labels), torch.cat(preds)))
    print("Precision:", precision_score(torch.cat(labels), torch.cat(preds)))
    print(preds)
    print(labels)

### MRI MODEL

In [11]:
#Instantiate ResNet Model
resnet = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0)
MRI_model = resnet.to(device)


# Set Seeds for deterministic results
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)
torch.backends.cudnn.benchmark = False
torch.cuda.manual_seed_all(101)

# Binary Cross Entropy Loss
error = nn.BCELoss()

# Define optimiser
optimiser = SGD(MRI_model.parameters(), lr=0.001,momentum=0.9)
scheduler = lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)

# Validation hyperparamters for early stopping
best_val_loss = np.inf
patience = 15
no_improvement = 0

for epoch in range(500):
    print(f"----------------------------EPOCH {epoch}-------------------------------")
    MRI_model.train()
    train_model(MRI_model, train_loader_MRI, optimiser, error, scheduler, True)
    
    MRI_model.eval()
    val_loss = validate_model(MRI_model, val_loader_MRI, error, True)
    
    # Save model weights if improvement seen
    # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
    if val_loss <= best_val_loss:
        best_val_loss = val_loss
        no_improvement = 0
        torch.save(MRI_model.state_dict(), "./trained_models_2/MRI")
    else:
        no_improvement += 1
        if no_improvement <= patience:
            continue
        else:
            print(f"BEST VAL LOSS: {best_val_loss}")
            break
            
print("----------------------------TEST RESULTS-------------------------------")
MRI_model.load_state_dict(torch.load("./trained_models_2/MRI"))
MRI_model.eval()
evaluate_model(MRI_model, test_loader_MRI, error, True)

----------------------------EPOCH 0-------------------------------
Average Training Loss 0.04420099248651598
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.051072308203069176
----------------------------EPOCH 1-------------------------------
Average Training Loss 0.04392788375987381
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.050936386352632104
----------------------------EPOCH 2-------------------------------
Average Training Loss 0.043687310375151084
Validation Accuracy: 0.4634146341463415
Average Validation Loss: 0.0506434440612793
----------------------------EPOCH 3-------------------------------
Average Training Loss 0.04355699300277428
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.05056708469623473
----------------------------EPOCH 4-------------------------------
Average Training Loss 0.04344169038241027
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05047665427370769
----------------------------

Average Training Loss 0.0374401257541336
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.05154989259998973
----------------------------EPOCH 43-------------------------------
Average Training Loss 0.0368335806444043
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.050244114747861536
BEST VAL LOSS: 0.049849315387446705
----------------------------TEST RESULTS-------------------------------
Test Accuracy: 0.6097560975609756
Recall: 0.7619047619047619
Precision: 0.5925925925925926
[tensor([1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1.]), tensor([1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 0.]), tensor([0., 1., 0., 1., 1., 1., 0., 1., 0.])]
[tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.]), tensor([0., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0.]), tensor([0., 0., 0., 1., 1., 1., 1., 1., 0.])]


In [10]:
# Print which subjects were correctly classified by the MRI model
y = [torch.tensor([1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1.]),
     torch.tensor([1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 0.]),
     torch.tensor([0., 1., 0., 1., 1., 1., 0., 1., 0.])]


predictions = [torch.tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.]),
               torch.tensor([0., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0.]),
               torch.tensor([0., 0., 0., 1., 1., 1., 1., 1., 0.])]
y = torch.cat(y)
predictions = torch.cat(predictions)
correct = y==predictions
for i in range(41):
    if correct[i]:
        print(ids_test[i])

OAS30314_MR_d.nii.gz
OAS30895_MR_d.nii.gz
OAS30830_MR_d.nii.gz
OAS31059_MR_d.nii.gz
OAS30079_MR_d.nii.gz
OAS30986_MR_d.nii.gz
OAS30929_MR_d.nii.gz
OAS30241_MR_d.nii.gz
OAS30032_MR_d.nii.gz
OAS30224_MR_d.nii.gz
OAS30921_MR_d.nii.gz
OAS30614_MR_d.nii.gz
OAS31092_MR_d.nii.gz
OAS31162_MR_d.nii.gz
OAS30041_MR_d.nii.gz
OAS31168_MR_d.nii.gz
OAS30028_MR_d.nii.gz
OAS30015_MR_d.nii.gz
OAS30735_MR_d.nii.gz
OAS31103_MR_d.nii.gz
OAS30749_MR_d.nii.gz
OAS30991_MR_d.nii.gz
OAS30223_MR_d.nii.gz
OAS30572_MR_d.nii.gz
OAS31087_MR_d.nii.gz


### PET MODEL

In [12]:
#Instantiate ResNet Model
resnet = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0)

PET_model = resnet.to(device)

# Set Seeds for deterministic results
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)

# Binary Cross Entropy Loss
error = nn.BCELoss()

# Define optimiser
optimiser = SGD(PET_model.parameters(), lr=0.001,momentum=0.9)
scheduler = lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)

# Validation hyperparamters for early stopping
best_val_loss = np.inf
patience = 15
no_improvement = 0

for epoch in range(500):
    print(f"----------------------------EPOCH {epoch}-------------------------------")
    PET_model.train()
    train_model(PET_model, train_loader_PET, optimiser, error, scheduler, True)
    
    PET_model.eval()
    avg_val_loss = validate_model(PET_model, val_loader_PET, error, True)
    
    # Save model weights if improvement seen
    # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
    if avg_val_loss <= best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement = 0
        torch.save(PET_model.state_dict(), "./trained_models_2/PET")
    else:
        no_improvement += 1
        if no_improvement <= patience:
            continue
        else:
            print(f"BEST VAL LOSS: {best_val_loss}")
            break
            
print("----------------------------TEST RESULTS-------------------------------")
PET_model.load_state_dict(torch.load("./trained_models_2/PET"))
PET_model.eval()
evaluate_model(PET_model, test_loader_PET, error, True)

----------------------------EPOCH 0-------------------------------
Average Training Loss 0.044133135285533844
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.05137293949359801
----------------------------EPOCH 1-------------------------------
Average Training Loss 0.04359371923520917
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.05120425980265548
----------------------------EPOCH 2-------------------------------
Average Training Loss 0.043256045975646036
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.05049649099024331
----------------------------EPOCH 3-------------------------------
Average Training Loss 0.04295000323995215
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.0504597032942423
----------------------------EPOCH 4-------------------------------
Average Training Loss 0.04275460529034255
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.050282975522483266
----------------------------

Average Training Loss 0.03689555145922254
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04905825562593413
----------------------------EPOCH 43-------------------------------
Average Training Loss 0.036685504691034064
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04840194015968137
----------------------------EPOCH 44-------------------------------
Average Training Loss 0.03719921636044002
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04833570776916132
----------------------------EPOCH 45-------------------------------
Average Training Loss 0.036640461534261703
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04829379552748145
----------------------------EPOCH 46-------------------------------
Average Training Loss 0.03598084264114255
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04825608468637234
----------------------------EPOCH 47-------------------------------
Average Training Loss 0

Average Training Loss 0.02812311889939621
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.046726379452682126
----------------------------EPOCH 85-------------------------------
Average Training Loss 0.02817716174682633
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05128361248388523
----------------------------EPOCH 86-------------------------------
Average Training Loss 0.02824559874954771
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.047338497347948025
----------------------------EPOCH 87-------------------------------
Average Training Loss 0.027765809512529216
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04554680644012079
----------------------------EPOCH 88-------------------------------
Average Training Loss 0.02724554707280925
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04630330132275093
----------------------------EPOCH 89-------------------------------
Average Training Loss 

In [11]:
# Print which subjects were correctly classified by the PET model
y = [torch.tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0.]),
     torch.tensor([0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0.]),
     torch.tensor([0., 0., 0., 1., 1., 1., 1., 1., 1.])]


predictions = [torch.tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.]),
               torch.tensor([0., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0.]),
               torch.tensor([0., 0., 0., 1., 1., 1., 1., 1., 0.])]
y = torch.cat(y)
predictions = torch.cat(predictions)
correct = y==predictions
for i in range(41):
    if correct[i]:
        print(ids_test[i])

OAS30314_MR_d.nii.gz
OAS30146_MR_d.nii.gz
OAS30895_MR_d.nii.gz
OAS30830_MR_d.nii.gz
OAS31059_MR_d.nii.gz
OAS30079_MR_d.nii.gz
OAS30986_MR_d.nii.gz
OAS30929_MR_d.nii.gz
OAS31125_MR_d.nii.gz
OAS30032_MR_d.nii.gz
OAS30224_MR_d.nii.gz
OAS30823_MR_d.nii.gz
OAS30907_MR_d.nii.gz
OAS30921_MR_d.nii.gz
OAS31344_MR_d.nii.gz
OAS30101_MR_d.nii.gz
OAS30769_MR_d.nii.gz
OAS30933_MR_d.nii.gz
OAS30959_MR_d.nii.gz
OAS30028_MR_d.nii.gz
OAS30015_MR_d.nii.gz
OAS30735_MR_d.nii.gz
OAS30160_MR_d.nii.gz
OAS31103_MR_d.nii.gz
OAS30749_MR_d.nii.gz
OAS30991_MR_d.nii.gz
OAS30223_MR_d.nii.gz
OAS30964_MR_d.nii.gz
OAS30572_MR_d.nii.gz


### Multimodal PET+MRI model

The model below combines the feature maps of the pretrained MRI and PET models for classification.

In [9]:
# Reload weights for rerunning
models = []
pet_model = ResNetV2.generate_model(
    model_depth=10,
    n_classes=1,
    n_input_channels=1,
    shortcut_type='B',
    conv1_t_size=7,
    conv1_t_stride=2,
    no_max_pool=False,
    widen_factor=1.0).to(device)
pet_model.load_state_dict(torch.load(f"./trained_models_2/PET"))

mri_model = ResNetV2.generate_model(
    model_depth=10,
    n_classes=1,
    n_input_channels=1,
    shortcut_type='B',
    conv1_t_size=7,
    conv1_t_stride=2,
    no_max_pool=False,
    widen_factor=1.0).to(device)
mri_model.load_state_dict(torch.load(f"./trained_models_2/MRI"))

models.append(pet_model)
models.append(mri_model)

In [None]:
# Define multimodal network
# This model uses the feature maps of the individual PET and MRI models as inputs
class multimodal_CNN(nn.Module):
    def __init__(self, image_models):
        super().__init__()
        self.image_models = nn.ModuleList(image_models)
        self.drop= nn.Dropout(p=0.4)
        
        self.fc3 = nn.Linear(200, 1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        image_outputs = []
        for i in range(2):
            image_output = self.image_models[i](x[i])[0]  
            image_outputs.append(self.drop(image_output))
        x = torch.cat(image_outputs, dim=1)
      
        x = self.sigmoid(self.fc3(x))
        return x

In [11]:
# Create Pytorch train and test sets
train_MRI_PET = torch.utils.data.TensorDataset(train_x_MRI,train_x_PET,train_y)
val_MRI_PET = torch.utils.data.TensorDataset(val_x_MRI,val_x_PET,val_y)
test_MRI_PET = torch.utils.data.TensorDataset(test_x_MRI,test_x_PET,test_y)



# Create data loader
batch_size = 16
train_loader_MRI_PET = torch.utils.data.DataLoader(train_MRI_PET, batch_size = batch_size, shuffle = True)
val_loader_MRI_PET = torch.utils.data.DataLoader(val_MRI_PET, batch_size = batch_size, shuffle = False)
test_loader_MRI_PET = torch.utils.data.DataLoader(test_MRI_PET, batch_size = batch_size, shuffle = False)


In [12]:
def train_model(model, train_loader, optimiser, error, scheduler, feature_maps=False):
    total_train_loss = 0
    for MRI_image, PET_image, labels in train_loader:
        MRI_image = MRI_image.to(device)
        PET_image = PET_image.to(device)
        labels = labels.to(device)
        # Clear gradients
        optimiser.zero_grad()
        
        # Forward propagation
        if feature_maps:
            outputs = model((PET_image, MRI_image))[1]
        else:
            outputs = model((PET_image, MRI_image))
        
        # Calculate loss
        loss = error(outputs.flatten(), labels)
        total_train_loss += loss.item()
        
        # Calculating gradients
        loss.backward()

        # Update parameters
        optimiser.step()
        scheduler.step()
        
        # Clear GPU Cache
        torch.cuda.empty_cache()
        gc.collect()
        
        # Calculate accuracy
        preds = outputs.flatten().round()
    print("Average Training Loss", total_train_loss/len(train_loader.dataset))

def validate_model(model, val_loader, error, feature_maps=False):
    correct_predictions_val = 0
    total_val_loss = 0
    with torch.no_grad():
        for MRI_image, PET_image, labels in val_loader:
            MRI_image = MRI_image.to(device)
            PET_image = PET_image.to(device)
            labels = labels.to(device)

            # Forward propagation
            if feature_maps:
                pred = model((PET_image, MRI_image))[1]
            else:
                pred = model((PET_image, MRI_image))
            
            # Calculate softmax and cross entropy loss
            loss = error(pred.flatten(), labels)
            total_val_loss += loss.item()

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Calculate accuracy
            preds = pred.flatten().round()
            #print(pred)
            correct_predictions_val += torch.sum(preds == labels).item()
        print("Validation Accuracy:", correct_predictions_val/len(val_loader.dataset))
        print("Average Validation Loss:", total_val_loss/len(val_loader.dataset))
    return total_val_loss/len(val_loader.dataset)
        
def evaluate_model(model, test_loader, error, feature_maps=False):
    correct_predictions_test = 0
    preds = []
    labels = []
    with torch.no_grad():
        for MRI_image, PET_image, label in test_loader:
            MRI_image = MRI_image.to(device)
            PET_image = PET_image.to(device)
            
            labels.append(label.cpu())
            
            # Forward propagation
            if feature_maps:
                pred = model((PET_image, MRI_image))[1]
            else:
                pred = model((PET_image, MRI_image))

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Calculate accuracy
            pred = pred.flatten().round()
            preds.append(pred.cpu())
            correct_predictions_test += torch.sum(pred.cpu() == label.cpu()).item()
    print("Test Accuracy:", correct_predictions_test/len(test_loader.dataset))    
    print("Recall:", recall_score(torch.cat(labels), torch.cat(preds)))
    print("Precision:", precision_score(torch.cat(labels), torch.cat(preds)))
    print(preds)
    print(labels)

In [17]:
multimodal_model = multimodel_CNN(models).to(device)

# Freeze layers of PET and MRI models
for image_models in multimodal_model.image_models:
    for name, param in image_models.named_parameters():
        param.requires_grad = False 

# Set Seeds for deterministic results
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)

# Binary Cross Entropy Loss
error = nn.BCELoss()

# Define optimiser
optimiser = SGD(multimodal_model.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)

# Validation hyperparameters for early stopping
best_val_loss = np.inf
patience = 20
no_improvement = 0

for epoch in range(500):
    print(f"----------------------------EPOCH {epoch}-------------------------------")
    multimodal_model.train()
    train_model(multimodal_model, train_loader_MRI_PET, optimiser, error, scheduler)
    
    multimodal_model.eval()
    avg_val_loss = validate_model(multimodal_model, val_loader_MRI_PET, error)
    
    # Save model weights if improvement seen
    # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
    if avg_val_loss <= best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement = 0
        torch.save(multimodal_model.state_dict(), "./trained_models_2/MRI_PET_NO_PATCH")
    else:
        no_improvement += 1
        if no_improvement <= patience:
            continue
        else:
            print(f"BEST VAL LOSS: {best_val_loss}")
            break
            
print("----------------------------TEST RESULTS-------------------------------")
multimodal_model.load_state_dict(torch.load("./trained_models_2/MRI_PET_NO_PATCH"))
multimodal_model.eval()
evaluate_model(multimodal_model, test_loader_MRI_PET, error)

----------------------------EPOCH 0-------------------------------
Average Training Loss 0.04293800157601716
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.05021972336420199
----------------------------EPOCH 1-------------------------------
Average Training Loss 0.04216503706134734
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04992436781162169
----------------------------EPOCH 2-------------------------------
Average Training Loss 0.04181580807341904
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.049638108509342846
----------------------------EPOCH 3-------------------------------
Average Training Loss 0.04079483070823013
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.049363178450886794
----------------------------EPOCH 4-------------------------------
Average Training Loss 0.040081167196641204
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.049117443038196096
--------------------------

Average Training Loss 0.0271853646416156
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.045797489038327845
----------------------------EPOCH 43-------------------------------
Average Training Loss 0.026774949959067047
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04555452742227694
----------------------------EPOCH 44-------------------------------
Average Training Loss 0.0270273392317725
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.045466809737973096
----------------------------EPOCH 45-------------------------------
Average Training Loss 0.026903593821115183
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.0454992462949055
----------------------------EPOCH 46-------------------------------
Average Training Loss 0.025727445595577113
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.0456531542103465
----------------------------EPOCH 47-------------------------------
Average Training Loss 0.