# Unimodal Patch Based networks

The following models utilise a patch based ensemble architecture. As a part of the ablation study, we train two unimodal patch based models, one trained on MRI images, and the other on PET images. There are two classifcation stages for the unimodal patch based model: 
- Patch Feature Extraction: 
    - MRI and PET scans are divided into 27 uniform patches of size 44x45x44. 
    - 27 ResNet models are trained on patches from each patch location to extract local features.
    - 3D ResNet architecture code was adapted from https://github.com/kenshohara/3D-ResNets-PyTorch
- Patch fusion:
    - Feature maps of the 27 patch models are concatenated and used as inputs to a final model for global classification

# Importing Libraries

In [1]:
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
import os
import gc
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import precision_score, recall_score, accuracy_score
from sklearn.preprocessing import StandardScaler
from imblearn.metrics import specificity_score
 
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import *
import nibabel as nb
import torchio as tio

from models import ResNetV2




## Read in images

In [2]:
# Subject IDs who are progressive normal cognition (will develop MCI or AD within 10 years)
PNC = pd.read_pickle('PNC.pkl')

# Subject IDs who are stable normal cognition (will remain CN within 10 years)
SNC = pd.read_pickle('SNC.pkl')

In [3]:
# Create datasets

def read_image_data(input_path):
    """Reads in MRI and PET image data from specified directory, along with corresponding subject ids.

    Parameters
    ----------
    input_path : str
        The input directory of the image data and subject ids
        
    Returns
    -------
    X_train : np.array
        a numpy array containing image data
        
    y_train : np.array
        a numpy array containing labels for image data
        
    ids : np.array
        a numpy array containing subject ids of image data
    """
    X_train = []
    y_train = []
    ids = []
    for filename in sorted(os.listdir(input_path)):
        file = os.path.join(input_path, filename)
        img_file = nb.load(file)
        img = img_file.get_fdata()
        X_train.append(img)
        ids.append(filename)
        # Progressive normal cognition target class 1
        if filename[0:8] in np.array(PNC):
            y_train.append(1)
        # Stable normal cognition class 0
        else:
            y_train.append(0)

    X_train = np.array(X_train)
    y_train = np.array(y_train)

    # Reshape X_train to include channel dimension
    X_train = X_train.reshape(X_train.shape + (1,))
    return X_train, y_train, ids

X_MRI, y_MRI, ids_MRI = read_image_data("E:/Work/Processed_MRI/2.MNI_Registration")
X_PET, y_PET, ids_PET = read_image_data("E:/Work/Processed_PIB/4.MNI_Registered")


## Creating training and test sets for images

In [4]:
# 0.6 - 0.2 - 0.2 train-val-test split
X_train_MRI, X_test_MRI, X_train_PET, X_test_PET, ids_train, ids_test, y_train, y_test = train_test_split(X_MRI, 
                                                                                     X_PET, ids_MRI, 
                                                                                     y_MRI, test_size=0.2, 
                                                                                     random_state=101, stratify=y_MRI)

X_train_MRI, X_val_MRI, X_train_PET, X_val_PET, ids_train, ids_val, y_train, y_val = train_test_split(X_train_MRI, X_train_PET, 
                                                                                      ids_train, y_train, 
                                                                                      test_size=0.25, random_state=101,
                                                                                      stratify=y_train)

In [5]:
# Convert datasets to tensor format, with channel first
train_x_MRI = torch.from_numpy(X_train_MRI).float().permute(0,4,1,2,3)
train_x_PET = torch.from_numpy(X_train_PET).float().permute(0,4,1,2,3)
train_y = torch.from_numpy(y_train).float()

val_x_MRI = torch.from_numpy(X_val_MRI).float().permute(0,4,1,2,3)
val_x_PET = torch.from_numpy(X_val_PET).float().permute(0,4,1,2,3)
val_y = torch.from_numpy(y_val).float()

test_x_MRI = torch.from_numpy(X_test_MRI).float().permute(0,4,1,2,3)
test_x_PET = torch.from_numpy(X_test_PET).float().permute(0,4,1,2,3)
test_y = torch.from_numpy(y_test).float()

## Perform data augmentation to increase training set size

In [None]:
def perform_augmentation(dataset, seed):
    """
    Performs augmentation to image data. To simulate different positions and size of the patient 
    within the scanner, and anatomical variations present in the images, random affine transformations, 
    elastic deformations and flips are applied to the images.

    Parameters
    ----------
    dataset : torch.tensor
        Pytorch tensor of image data being augmented
        
    seed : int
        Seed for random number generation
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Define transformations
    training_transform = tio.Compose([
        tio.RandomAffine(),
        tio.RandomElasticDeformation(),
        tio.RandomFlip()
    ])
    augmented_dataset = torch.clone(dataset) 
    for i in range(len(augmented_dataset)):
        augmented_dataset[i] = training_transform(augmented_dataset[i])
    return augmented_dataset

orig_train_x_MRI = torch.clone(train_x_MRI)
orig_train_x_PET = torch.clone(train_x_PET)
orig_train_y = torch.clone(train_y)
for seed in [1,101,42]:
    # Apply transformations and create augmented training set
    augmented_train_MRI = perform_augmentation(orig_train_x_MRI, seed)
    augmented_train_PET = perform_augmentation(orig_train_x_PET, seed)
    
    
    # Concatenate training and augmented training datasets
    train_x_MRI = torch.cat((train_x_MRI, augmented_train_MRI), 0)
    train_x_PET = torch.cat((train_x_PET, augmented_train_PET), 0)
    train_y = torch.cat((train_y, orig_train_y), 0)

# Removing blank space around brain to make dimensions even so we can divide image into uniform patches
train_x_MRI = train_x_MRI[:, :, 0:88, 0:108, 0:88]
train_x_PET = train_x_PET[:, :, 0:88, 0:108, 0:88]   

In [6]:
# Load augmented dataset (for reruns)
train_x_MRI = torch.load("train_x_MRI.pkl")
train_x_PET = torch.load("train_x_PET.pkl")
train_y = torch.load("train_y.pkl")

## Divide images into patches

In [7]:
def create_patches(images, dim, labels):
    """
    For each sample, divides images into 27 uniform 3x3x3 patches of size 44x54x44 with 50% overlap
    Then creates 27 training datasets for each patch location 
    For example a dataset containing patches of each subject in the top left corner
    and a dataset containing patches of each subject in the middle 

    Parameters
    ----------
    images : torch.tensor
        Pytorch tensor containing image data
        
    dim : int
        Dimension of the input tensor
    
    labels : torch.tensor
        Pytorch tensor containing image data class labels
    
    Returns
    -------
    datasets : list
        List of TensorDatasets containing patches for each patch location
    """
    # Create datasets for each patch location
    datasets = [[] for _ in range(27)]
    
    # Find the starting locations of each patch
    starting_points = []
    for height_stride in range(3):
        for width_stride in range(3):
            for depth_stride in range(3):
                    start = (0 + height_stride*dim[1]//4, 0 + width_stride*dim[2]//4, 0 + depth_stride*dim[3]//4)
                    starting_points.append(start)

    # For each image
    for i in range(len(images)):
        # Create patch from every starting point
        for j in range(len(starting_points)):
            start_pt = starting_points[j]
            patch = images[i][:, start_pt[0]:start_pt[0] + dim[1]//2, 
                                      start_pt[1]:start_pt[1] + dim[2]//2, 
                                      start_pt[2]:start_pt[2] + dim[3]//2]
            datasets[j].append(patch)
            
    # For each patch location, stack patches from every sample samples into a tensor as a Tensor Dataset
    for i in range(len(datasets)):
        datasets[i] = torch.stack(datasets[i])
        datasets[i] = torch.utils.data.TensorDataset(datasets[i],labels)
    return datasets


In [8]:
# Create datasets for each patch location

dim = train_x_MRI[0].shape
patch_train_datasets_MRI = create_patches(train_x_MRI, dim, train_y)
patch_train_datasets_PET = create_patches(train_x_PET, dim, train_y)

patch_val_datasets_MRI = create_patches(val_x_MRI, dim, val_y)
patch_val_datasets_PET = create_patches(val_x_PET, dim, val_y)

patch_test_datasets_MRI = create_patches(test_x_MRI, dim, test_y)
patch_test_datasets_PET = create_patches(test_x_PET, dim, test_y)


# Create data loaders for each patch location dataset
patch_train_dataloaders_MRI = []
patch_train_dataloaders_PET = []

patch_val_dataloaders_MRI = []
patch_val_dataloaders_PET = []

patch_test_dataloaders_MRI = []
patch_test_dataloaders_PET = []

batch_size = 16
for i in range(27):
    patch_train_dataloaders_MRI.append(torch.utils.data.DataLoader(patch_train_datasets_MRI[i], 
                                                                  batch_size = batch_size, shuffle = True))
    patch_train_dataloaders_PET.append(torch.utils.data.DataLoader(patch_train_datasets_PET[i], 
                                                                  batch_size = batch_size, shuffle = True))
    
    patch_val_dataloaders_MRI.append(torch.utils.data.DataLoader(patch_val_datasets_MRI[i], 
                                                                  batch_size = batch_size, shuffle = False))
    patch_val_dataloaders_PET.append(torch.utils.data.DataLoader(patch_val_datasets_PET[i], 
                                                                  batch_size = batch_size, shuffle = False))
    
    patch_test_dataloaders_MRI.append(torch.utils.data.DataLoader(patch_test_datasets_MRI[i], 
                                                                  batch_size = batch_size, shuffle = False))
    patch_test_dataloaders_PET.append(torch.utils.data.DataLoader(patch_test_datasets_PET[i], 
                                                                  batch_size = batch_size, shuffle = False))

In [9]:
# Define device being used for model training
if torch.cuda.is_available(): 
    dev = "cuda:0" 
else: 
    dev = "cpu" 
device = torch.device(dev) 

# Model training

In [10]:
def train_model(model, train_loader, optimiser, error, scheduler, feature_map=False):
    """
    Trains model and performs gradient updates using training data.

    Parameters
    ----------
    model : torch.nn.Module
        The pytorch model to be trained

    train_loader : torch.utils.data.DataLoader
        Dataloader used for training set
        
    optimiser : torch.optim optimisers for pytorch e.g torch.optim.SGD
        Optimiser used for model training
        
    error : torch.nn loss functions for pytorch e.g nn.BCELoss
        Loss function used for model training
        
    scheduler : torch.optim.lr_scheduler
        Scheduler used to adjust learning rate during training
        
    feature_map : Boolean
        Flag indicating whether the input model returns both the final layer output
        and previous feature maps
    """
    total_train_loss = 0
    for image, labels in train_loader:
        image = image.to(device)
        labels = labels.to(device)
        
        # Clear gradients
        optimiser.zero_grad()
        
        # Forward propagation
        # feature_map indicates whether the feature map prior final layer is also included in the outputs of the model
        if feature_map:
            outputs = model(image)[1]
        else:
            outputs = model(image)
        
        # Calculate loss
        loss = error(outputs.flatten(), labels)
        
        total_train_loss += loss.item()
        
        # Calculating gradients
        loss.backward()

        # Update parameters
        optimiser.step()
        scheduler.step()
        
        # Clear GPU Cache
        torch.cuda.empty_cache()
        gc.collect()
        
        # Calculate accuracy
        preds = outputs.flatten().round()
    #print("Average Training Loss", total_train_loss/len(train_loader.dataset))

def validate_model(model, val_loader, error, feature_map=False):
    """
    Tests validation set on input model

    Parameters
    ----------
    model : torch.nn.Module
        The pytorch model to be trained

    val_loader : torch.utils.data.DataLoader
        Dataloader used for validation set
        
    error : torch.nn loss functions for pytorch e.g nn.BCELoss
        Loss function used for model training
        
    feature_map : Boolean
        Flag indicating whether the input model returns both the final layer output
        and previous feature maps
        
    Returns
    -------
    validation_loss : float
        validation loss of model
    """
    correct_predictions_val = 0
    total_val_loss = 0
    with torch.no_grad():
        for image, labels in val_loader:
            image = image.to(device)
            labels = labels.to(device)

            # Forward propagation
            if feature_map:
                pred = model(image)[1]
            else:
                pred = model(image)
            
            # Calculate loss
            loss = error(pred.flatten(), labels)
            total_val_loss += loss.item()

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Calculate accuracy
            preds = pred.flatten().round()
            #print(pred)
            correct_predictions_val += torch.sum(preds == labels).item()
        #print("Validation Accuracy:", correct_predictions_val/len(val_loader.dataset))
        #print("Average Validation Loss:", total_val_loss/len(val_loader.dataset))
    validation_loss = total_val_loss/len(val_loader.dataset)
    return validation_loss
        
def evaluate_model(model, test_loader, error, feature_map=False, verbose=1):
    """
    Evaluates input model on test set. Prints out the model's accuracy, true positive rate,
    true negative rate, and predictions against the true labels.

    Parameters
    ----------
    model : torch.nn.Module
        The pytorch model to be trained

    test_loader : torch.utils.data.DataLoader
        Dataloader used for test set
        
    error : torch.nn loss functions for pytorch e.g nn.BCELoss
        Loss function used for model training
        
    feature_map : Boolean
        Flag indicating whether the input model returns both the final layer output
        and previous feature maps
        
    verbpse : Boolean
        Flag indicating whether to print final test scores, or return them instead
    """
    correct_predictions_test = 0
    preds = []
    labels = []
    with torch.no_grad():
        for test,label in test_loader:
            test = test.to(device)
            label = label.to(device)
            labels.append(label.cpu())
            
            # Forward propagation
            if feature_map:
                pred = model(test)[1]
            else:
                pred = model(test)

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Calculate accuracy
            pred = pred.flatten().round()
            preds.append(pred.cpu())
            correct_predictions_test += torch.sum(pred == label).item()
    if verbose==1:
        print("Test Accuracy:", correct_predictions_test/len(test_loader.dataset))    
        print("True Positive Rate:", recall_score(torch.cat(labels), torch.cat(preds)))
        print("True Negative Rate:", specificity_score(torch.cat(labels), torch.cat(preds)))
        print("Predictions:",preds)
        print("True Labels:",labels)
    else:
        return (correct_predictions_test/len(test_loader.dataset), recall_score(torch.cat(labels), torch.cat(preds)),
               specificity_score(torch.cat(labels), torch.cat(preds)))

## 1. Patch Feature Extraction

27 models are trained for each patch location for each modality

### Training MRI individual patch models

In [15]:
# Set seeds 
torch.manual_seed(101)
torch.cuda.manual_seed(101)
torch.cuda.manual_seed_all(101)
random.seed(101)
np.random.seed(101)
torch.backends.cudnn.benchmark = False

mri_patch_models = []
mri_accuracies = []
mri_true_pos_rates = []
mri_true_neg_rates = []

# For each patch location, train a ResNet model on the image patches in that location
for i in range(27):
    patch_model = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0).to(device)
    
    best_val_loss = np.inf
    patience = 20
    no_improvement = 0
    
    # Binary Cross Entropy Loss
    error = nn.BCELoss()

    # SGD Optimizer
    optimiser = SGD(patch_model.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)
    
    for epoch in range(1000):
        #print(f"----------------------------EPOCH {epoch} PATCH {i}-------------------------------")
        patch_model.train()
        train_model(patch_model, patch_train_dataloaders_MRI[i], optimiser, error, scheduler, True)

        patch_model.eval()
        val_loss = validate_model(patch_model, patch_val_dataloaders_MRI[i], error, True)
        
        # Save model weights if improvement seen
        # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(patch_model.state_dict(), f"./patch_temp/PATCH_TEMP{i}")
        else:
            no_improvement += 1
            if no_improvement <= patience:
                continue
            else:
                #print(f"BEST VAL ACC: {best_val_loss}")
                break
         
    #print(f"----------------------------TEST RESULTS PATCH{i}-------------------------------")
    patch_model.load_state_dict(torch.load(f"./patch_temp/PATCH_TEMP{i}"))
    patch_model.eval()
    scores = evaluate_model(patch_model, patch_test_dataloaders_MRI[i], error, True,verbose=0)
    mri_accuracies.append(scores[0])
    mri_true_pos_rates.append(scores[1])
    mri_true_neg_rates.append(scores[2])
    mri_patch_models.append(patch_model)

In [16]:
mri_patch_scores = pd.DataFrame(list(zip([i for i in range(1,28)], mri_accuracies, mri_true_pos_rates, mri_true_neg_rates)),
              columns=['Patch Number','Test Accuracy','Test True Positive Rate', 'Test True Negative Rate'])
mri_patch_scores.set_index('Patch Number').round(2)

Unnamed: 0_level_0,Test Accuracy,Test True Positive Rate,Test True Negative Rate
Patch Number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,0.59,0.95,0.2
2,0.66,0.76,0.55
3,0.61,0.71,0.5
4,0.66,0.67,0.65
5,0.56,0.76,0.35
6,0.59,0.71,0.45
7,0.59,0.76,0.4
8,0.71,0.81,0.6
9,0.54,0.67,0.4
10,0.63,0.71,0.55


In [21]:
# Reload weights (for rerunning)
mri_models = []
for i in range(27):
    patch_model = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0).to(device)
    patch_model.load_state_dict(torch.load(f"./patch_temp/PATCH_TEMP{i}"))
    mri_models.append(patch_model)
    
    

### Training PET Individual Patch Models

In [11]:
# Set seeds
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)


pet_patch_models = []
pet_accuracies = []
pet_true_pos_rates = []
pet_true_neg_rates = []

# For each patch location, train a resnet model on the image patches in that location
for i in range(27):
    patch_model = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0).to(device)
    
    
    best_val_loss = np.inf
    patience = 20
    no_improvement = 0
    
    # Binary Cross Entropy Loss
    error = nn.BCELoss()

    # SGD Optimizer
    optimiser = SGD(patch_model.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)
    
    for epoch in range(1000):
        #print(f"----------------------------EPOCH {epoch} PATCH {i}-------------------------------")
        patch_model.train()
        train_model(patch_model, patch_train_dataloaders_PET[i], optimiser, error, scheduler, True)

        patch_model.eval()
        val_loss = validate_model(patch_model, patch_val_dataloaders_PET[i], error, True)
        
        # Save model weights if improvement seen
        # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(patch_model.state_dict(), f"./patch_temp/PET_PATCH_TEMP{i}")
        else:
            no_improvement += 1
            if no_improvement <= patience:
                continue
            else:
                #print(f"BEST VAL ACC: {best_val_loss}")
                break
         
    #print("----------------------------TEST RESULTS-------------------------------")
    patch_model.load_state_dict(torch.load(f"./patch_temp/PET_PATCH_TEMP{i}"))
    patch_model.eval()
    scores = evaluate_model(patch_model, patch_test_dataloaders_PET[i], error, True,verbose=0)
    pet_accuracies.append(scores[0])
    pet_true_pos_rates.append(scores[1])
    pet_true_neg_rates.append(scores[2])
    pet_patch_models.append(patch_model)

In [13]:
pet_patch_scores = pd.DataFrame(list(zip([i for i in range(1,28)], pet_accuracies, pet_true_pos_rates, pet_true_neg_rates)),
              columns=['Patch Number','Test Accuracy','Test True Positive Rate', 'Test True Negative Rate'])
pet_patch_scores.set_index('Patch Number').round(2)

Unnamed: 0_level_0,Test Accuracy,Test True Positive Rate,Test True Negative Rate
Patch Number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,0.73,0.71,0.75
2,0.73,0.67,0.8
3,0.71,0.67,0.75
4,0.63,0.57,0.7
5,0.71,0.67,0.75
6,0.73,0.67,0.8
7,0.71,0.57,0.85
8,0.66,0.57,0.75
9,0.68,0.67,0.7
10,0.68,0.62,0.75


In [11]:
# Reload weights (for rerunning)
pet_models = []
for i in range(27):
    patch_model = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0).to(device)
    patch_model.load_state_dict(torch.load(f"./patch_temp/PET_PATCH_TEMP{i}"))
    pet_models.append(patch_model)
    

# 2. Patch Fusion

The models trained below are at the whole image level. Features from the 27 models trained on each patch location will be concatenated and used for the final global classification.

In [12]:
# Stacking all 27 patch based datasets to create a single subject level dataset
# We will now have a subject level dataset where each row is a subject and their 27 patch images
# e.g row 1 would be: [patch_1,....,patch_27] for subject 1

all_patches_train_MRI = torch.stack([dataset.tensors[0] for dataset in patch_train_datasets_MRI], dim=1)
all_patches_val_MRI = torch.stack([dataset.tensors[0] for dataset in patch_val_datasets_MRI], dim=1)
all_patches_test_MRI = torch.stack([dataset.tensors[0] for dataset in patch_test_datasets_MRI], dim=1)

all_patches_train_PET = torch.stack([dataset.tensors[0] for dataset in patch_train_datasets_PET], dim=1)
all_patches_val_PET = torch.stack([dataset.tensors[0] for dataset in patch_val_datasets_PET], dim=1)
all_patches_test_PET = torch.stack([dataset.tensors[0] for dataset in patch_test_datasets_PET], dim=1)

In [13]:
batch_size = 16
all_patches_train_dataloader_MRI = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(all_patches_train_MRI, train_y), 
                                                                  batch_size = batch_size, shuffle = True)
all_patches_val_dataloader_MRI = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(all_patches_val_MRI, val_y), 
                                                                 batch_size = batch_size, shuffle = False)
all_patches_test_dataloader_MRI = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(all_patches_test_MRI, test_y), 
                                                                  batch_size = batch_size, shuffle = False)

all_patches_train_dataloader_PET = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(all_patches_train_PET, train_y), 
                                                                  batch_size = batch_size, shuffle = True)
all_patches_val_dataloader_PET = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(all_patches_val_PET, val_y), 
                                                                  batch_size = batch_size, shuffle = False)
all_patches_test_dataloader_PET = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(all_patches_test_PET, test_y), 
                                                                  batch_size = batch_size, shuffle = False)


In [14]:
# Define our final multi stream network
# This network uses the feature maps of the 27 patch models trained previously above as inputs
class patch_fusion_CNN(nn.Module):
    """
    This network uses the feature maps of the 27 patch models trained previously above as inputs.
    Feature maps are concatenated and passed through a dense layer.
    """
    def __init__(self, patch_models):
        super().__init__()
        self.patch_models = nn.ModuleList(patch_models)
        self.drop= nn.Dropout(p=0.4)
        
        self.fc1 = nn.Linear(2700, 1000) 
        self.rel1 = nn.ReLU()
        self.batch_norm1 = nn.BatchNorm1d(1000)
        
        self.fc2 = nn.Linear(1000, 200) 
        self.rel2 = nn.ReLU()
        self.batch_norm2 = nn.BatchNorm1d(200)
        
        self.fc3 = nn.Linear(2700, 1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        patch_outputs = []
        for i in range(27):
            patch_output, _ = self.patch_models[i](x[:, i])  
            patch_outputs.append(self.drop(patch_output))
        x = torch.cat(patch_outputs, dim=1)
        x = self.sigmoid(self.fc3(x))
        return x

### MRI ensemble patch model

In [17]:
MRI_patch_fusion_model = patch_fusion_CNN(mri_models).to(device)

#Freeze layers of the 27 individual patch models
for patch_model in MRI_patch_fusion_model.patch_models:
    for name, param in patch_model.named_parameters():      
            param.requires_grad = False 

# CNN model training
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)

# Binary Cross Entropy Loss
error = nn.BCELoss()

# SGD Optimizer
optimiser = SGD(MRI_patch_fusion_model.parameters(), lr=0.0001, momentum=0.9)
scheduler = lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)

# Validation Hyperparameters for early stopping
best_val_loss = np.inf
patience = 20
no_improvement = 0

for epoch in range(1000):
    #print(f"----------------------------EPOCH {epoch}-------------------------------")
    MRI_patch_fusion_model.train()
    train_model(MRI_patch_fusion_model, all_patches_train_dataloader_MRI, optimiser, error, scheduler)
    
    MRI_patch_fusion_model.eval()
    avg_val_loss = validate_model(MRI_patch_fusion_model, all_patches_val_dataloader_MRI, error)
    
    # Save model weights if improvement seen
    # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
    if avg_val_loss <= best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement = 0
        torch.save(MRI_patch_model.state_dict(), "./trained_models_2/MRI_PATCH_MODEL")
    else:
        no_improvement += 1
        if no_improvement <= patience:
            continue
        else:
            print(f"BEST VAL LOSS: {best_val_loss}")
            break
            
print("----------------------------TEST RESULTS-------------------------------")
MRI_patch_fusion_model.load_state_dict(torch.load("./trained_models_2/MRI_PATCH_MODEL"))
MRI_patch_fusion_model.eval()
evaluate_model(MRI_patch_fusion_model, all_patches_test_dataloader_MRI, error)

BEST VAL LOSS: 0.04719100056624994
----------------------------TEST RESULTS-------------------------------
Test Accuracy: 0.6829268292682927
True Positive Rate: 0.8095238095238095
True Negative Rate: 0.55


### PET ensemble patch model

In [15]:
PET_patch_fusion_model = patch_fusion_CNN(pet_models).to(device)

#Freeze layers of individual patch models
for patch_model in PET_patch_fusion_model.patch_models:
    for name, param in patch_model.named_parameters():
        param.requires_grad = False 

# CNN model training
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)

# Binary Cross Entropy Loss
error = nn.BCELoss()

# SGD Optimizer
optimiser = SGD(PET_patch_fusion_model.parameters(), lr=0.0001, momentum=0.9)
scheduler = lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)

# Validation Hyperparameters for early stopping
best_val_loss = np.inf
patience = 20
no_improvement = 0

for epoch in range(1000):
    #print(f"----------------------------EPOCH {epoch}-------------------------------")
    PET_patch_fusion_model.train()
    train_model(PET_patch_fusion_model, all_patches_train_dataloader_PET, optimiser, error, scheduler)
    
    PET_patch_fusion_model.eval()
    avg_val_loss = validate_model(PET_patch_fusion_model, all_patches_val_dataloader_PET, error)
    
    # Save model weights if improvement seen
    # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
    if avg_val_loss <= best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement = 0
        torch.save(PET_patch_model.state_dict(), "./trained_models_2/1PET_PATCH_MODEL")
    else:
        no_improvement += 1
        if no_improvement <= patience:
            continue
        else:
            print(f"BEST VAL LOSS: {best_val_loss}")
            break
            
print("----------------------------TEST RESULTS-------------------------------")
PET_patch_fusion_model.load_state_dict(torch.load("./trained_models_2/1PET_PATCH_MODEL"))
PET_patch_fusion_model.eval()
evaluate_model(PET_patch_fusion_model, all_patches_test_dataloader_PET, error)

BEST VAL LOSS: 0.0446309679892005
----------------------------TEST RESULTS-------------------------------
Test Accuracy: 0.7560975609756098
True Positive Rate: 0.6666666666666666
True Negative Rate: 0.85
Predictions: [tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1.]), tensor([0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0.]), tensor([0., 0., 0., 1., 1., 1., 1., 1., 0.])]
True Labels: [tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.]), tensor([0., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0.]), tensor([0., 0., 0., 1., 1., 1., 1., 1., 0.])]
