# Multimodal Patch Based networks

The following model expands on the patch-based architectures by combining feature maps of both MRI and PET modalities using attention-based mechanisms. As part of the ablation study, we also train a model without attention-based mechanisms. There are three classification stages for the final multimodal attention-based model: 
- <b>Patch Feature Extraction:</b> 
    - 27 ResNet models are trained on patches from each patch location to extract local features.
- <b>Multimodal Attention:</b>
    - For each patch location, we train a model which combines the PET and MRI feature maps at that location using the corresponding patch models.  Multihead attention is used to capture the relationships between both modalities.
- <b>Patch fusion:</b>
    - Feature maps of the 27 trained attention-patch models are concatenated and used as inputs to a final model for global classification

## Importing Libraries

In [None]:
# importing the libraries
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
import os
import gc
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import precision_score, recall_score, accuracy_score
from sklearn.preprocessing import StandardScaler

 
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import *
import nibabel as nb
import torchio as tio

from models import (cnn, C3DNet, resnet, ResNetV2, ResNeXt, ResNeXtV2, WideResNet, PreActResNet,
        EfficientNet, DenseNet, ShuffleNet, ShuffleNetV2, SqueezeNet, MobileNet, MobileNetV2)




## Read in images

In [None]:
# Subject IDs with progressive normal cognition (labelled as class 1 below).
# NOTE: read_pickle unpickles arbitrary objects; only load files you trust.
PNC = pd.read_pickle("PNC.pkl")

# Subject IDs with stable normal cognition.
SNC = pd.read_pickle("SNC.pkl")

In [None]:
# Create datasets

def read_image_data(input_path, progressive_ids=None):
    """Load every image in a directory and label it by subject ID.

    Files are visited in sorted filename order, so the returned volumes,
    labels and filenames are aligned by position.

    Args:
        input_path: Directory whose files are all loadable by
            ``nibabel.load``.
        progressive_ids: Optional collection of subject IDs belonging to the
            progressive normal cognition class. Defaults to the module-level
            ``PNC`` table.

    Returns:
        Tuple ``(X, y, ids)``: ``X`` is the stacked image array with a
        trailing channel axis of size 1, ``y`` holds 1 for subjects found in
        ``progressive_ids`` and 0 otherwise, and ``ids`` is the list of
        filenames in load order.
    """
    if progressive_ids is None:
        progressive_ids = PNC
    # Build the lookup once instead of converting the ID table for every file.
    positive_ids = set(np.asarray(progressive_ids).ravel().tolist())

    volumes = []
    labels = []
    ids = []
    for filename in sorted(os.listdir(input_path)):
        image = nb.load(os.path.join(input_path, filename))
        volumes.append(image.get_fdata())
        ids.append(filename)
        # The first 8 characters of the filename are used as the subject ID:
        # progressive normal cognition -> 1, everything else -> 0.
        labels.append(1 if filename[0:8] in positive_ids else 0)

    X_train = np.array(volumes)
    y_train = np.array(labels)

    # Append a trailing channel dimension of size 1.
    X_train = X_train.reshape(X_train.shape + (1,))
    return X_train, y_train, ids

X_MRI, y_MRI, ids_MRI = read_image_data("E:/Work/Processed_MRI/2.MNI_Registration")
X_PET, y_PET, ids_PET = read_image_data("E:/Work/Processed_PIB/4.MNI_Registered")

# The two modalities are paired purely by sorted position and only the MRI
# labels are used downstream, so a missing or extra scan in either directory
# would silently pair an MRI volume with another subject's PET volume.
# Warn (rather than fail) when the 8-character subject prefixes disagree.
import warnings

_mri_subjects = [name[0:8] for name in ids_MRI]
_pet_subjects = [name[0:8] for name in ids_PET]
if _mri_subjects != _pet_subjects:
    warnings.warn(
        "MRI and PET directories do not list the same subjects in the same "
        "order; the modalities may be misaligned."
    )


## Creating training and test sets for images

In [None]:
# 0.6 - 0.2 - 0.2 train-val-test split, stratified by label.
# Step 1: hold out 20% of all subjects as the test set.
(X_train_MRI, X_test_MRI,
 X_train_PET, X_test_PET,
 ids_train, ids_test,
 y_train, y_test) = train_test_split(
    X_MRI, X_PET, ids_MRI, y_MRI,
    test_size=0.2, random_state=101, stratify=y_MRI,
)

# Step 2: take 25% of the remaining 80% (20% overall) as the validation set.
(X_train_MRI, X_val_MRI,
 X_train_PET, X_val_PET,
 ids_train, ids_val,
 y_train, y_val) = train_test_split(
    X_train_MRI, X_train_PET, ids_train, y_train,
    test_size=0.25, random_state=101, stratify=y_train,
)

In [None]:
# Convert datasets to tensor format, with channel first
def _to_channel_first(volumes):
    """Turn a channel-last NumPy batch into a float32 channel-first tensor."""
    return torch.from_numpy(volumes).float().permute(0, 4, 1, 2, 3)


# Image volumes: move the trailing channel axis to position 1.
train_x_MRI = _to_channel_first(X_train_MRI)
train_x_PET = _to_channel_first(X_train_PET)
val_x_MRI = _to_channel_first(X_val_MRI)
val_x_PET = _to_channel_first(X_val_PET)
test_x_MRI = _to_channel_first(X_test_MRI)
test_x_PET = _to_channel_first(X_test_PET)

# Labels: stored as float32.
train_y = torch.from_numpy(y_train).float()
val_y = torch.from_numpy(y_val).float()
test_y = torch.from_numpy(y_test).float()

## Perform data augmentation to increase training set size

In [None]:
def perform_augmentation(dataset, seed):
    """Return a randomly transformed copy of ``dataset``.

    The torch (CPU and CUDA), ``random`` and NumPy global generators are all
    re-seeded with ``seed`` before any transform is drawn. The input tensor
    itself is left untouched; a clone is transformed and returned.

    Args:
        dataset: Tensor whose first dimension indexes the individual volumes.
        seed: Integer used to re-seed every global random generator.

    Returns:
        Tensor with the same shape as ``dataset`` in which every volume has
        been passed through random affine, elastic deformation and flip
        transforms (torchio defaults).
    """
    for reseed in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
        reseed(seed)

    # Transformation pipeline applied to each volume in turn.
    pipeline = tio.Compose([
        tio.RandomAffine(),
        tio.RandomElasticDeformation(),
        tio.RandomFlip(),
    ])

    augmented = dataset.clone()
    for index in range(len(augmented)):
        augmented[index] = pipeline(augmented[index])
    return augmented

# Keep pristine copies so every augmentation round starts from the originals.
orig_train_x_MRI = train_x_MRI.clone()
orig_train_x_PET = train_x_PET.clone()
orig_train_y = train_y.clone()

for seed in (1, 101, 42):
    # Both modalities are augmented with the same seed.
    augmented_train_MRI = perform_augmentation(orig_train_x_MRI, seed)
    augmented_train_PET = perform_augmentation(orig_train_x_PET, seed)

    # Append the augmented copies (and a matching copy of the labels)
    # to the growing training set.
    train_x_MRI = torch.cat([train_x_MRI, augmented_train_MRI], dim=0)
    train_x_PET = torch.cat([train_x_PET, augmented_train_PET], dim=0)
    train_y = torch.cat([train_y, orig_train_y], dim=0)

# Trim blank space around the brain so the spatial dimensions divide evenly
# into uniform patches.
train_x_MRI = train_x_MRI[:, :, :88, :108, :88]
train_x_PET = train_x_PET[:, :, :88, :108, :88]

In [None]:
# Load augmented dataset (for reruns)
# These assignments replace whatever train_x_MRI / train_x_PET / train_y
# currently hold with previously saved tensors.
# NOTE(review): no matching torch.save call is visible in this notebook;
# confirm the files were written after augmentation and cropping.
train_x_MRI = torch.load("train_x_MRI.pkl")
train_x_PET = torch.load("train_x_PET.pkl")
train_y = torch.load("train_y.pkl")

## Divide images into patches

In [None]:
# For each sample, divide images into 27 uniform 3x3x3 patches of size 44x54x44 with 50% overlap
# Then create 27 training datasets for each patch location 
# e.g a data set for patches of each subject in the top left corner, ..., a dataset for patches of each subject in the middle

def create_MRI_PET_patches(images, dim, labels):
    """Split paired MRI/PET volumes into 27 overlapping patch datasets.

    Each spatial axis is covered by three patches of half the axis length
    that start at 0, 1/4 and 1/2 of that length (50% overlap), giving a
    3x3x3 grid of patch locations.

    Args:
        images: Pair ``(mri, pet)`` of batches indexed as
            ``[subject, channel, height, width, depth]``. Tensors are sliced
            directly; other sequences of tensors are stacked first.
        dim: Shape ``(channels, height, width, depth)`` that defines the
            patch grid. Volumes larger than ``dim`` are only sampled inside
            their leading ``dim``-sized region.
        labels: Tensor with one label per subject.

    Returns:
        List of 27 ``TensorDataset`` objects, one per patch location, each
        yielding ``(mri_patch, pet_patch, label)``. Locations are ordered
        with depth varying fastest, then width, then height.
    """
    def as_batch(volumes):
        # Accept a list of per-subject tensors as well as one batch tensor.
        return volumes if torch.is_tensor(volumes) else torch.stack(list(volumes))

    mri_volumes = as_batch(images[0])
    pet_volumes = as_batch(images[1])

    patch_size = (dim[1] // 2, dim[2] // 2, dim[3] // 2)

    # Starting corner of every patch location in the 3x3x3 grid.
    starting_points = [
        (height_stride * dim[1] // 4,
         width_stride * dim[2] // 4,
         depth_stride * dim[3] // 4)
        for height_stride in range(3)
        for width_stride in range(3)
        for depth_stride in range(3)
    ]

    def crop(volumes, start):
        # Slice the whole batch at once (no per-subject Python loop), then
        # copy so each dataset owns its own contiguous storage.
        patch = volumes[:, :,
                        start[0]:start[0] + patch_size[0],
                        start[1]:start[1] + patch_size[1],
                        start[2]:start[2] + patch_size[2]]
        return patch.clone(memory_format=torch.contiguous_format)

    # One dataset per patch location, holding that patch for every subject.
    return [
        torch.utils.data.TensorDataset(
            crop(mri_volumes, start), crop(pet_volumes, start), labels
        )
        for start in starting_points
    ]


In [None]:
# Build the per-location patch datasets for every split.
# The patch grid is taken from the shape of a single training volume.
dim = train_x_MRI[0].shape

patch_train_datasets_MRI_PET, patch_val_datasets_MRI_PET, patch_test_datasets_MRI_PET = (
    create_MRI_PET_patches((mri_split, pet_split), dim, label_split)
    for mri_split, pet_split, label_split in (
        (train_x_MRI, train_x_PET, train_y),
        (val_x_MRI, val_x_PET, val_y),
        (test_x_MRI, test_x_PET, test_y),
    )
)

# Containers for the data loaders of each patch location dataset.
patch_train_loaders_MRI_PET, patch_val_loaders_MRI_PET, patch_test_loaders_MRI_PET = [], [], []

In [None]:
batch_size = 16

# Build each list in a single assignment so that re-running this cell
# replaces the loaders instead of appending a second set to the same lists.
# Only the training loaders shuffle; validation and test keep a fixed order.
patch_train_loaders_MRI_PET = [
    torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for dataset in patch_train_datasets_MRI_PET
]
patch_val_loaders_MRI_PET = [
    torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for dataset in patch_val_datasets_MRI_PET
]
patch_test_loaders_MRI_PET = [
    torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for dataset in patch_test_datasets_MRI_PET
]

In [None]:
# Train on the first CUDA GPU when one is available, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

## Model training

In [None]:
def train_model(model, train_loader, optimiser, error, scheduler, feature_maps=False):
    """Run one training epoch and print the mean per-sample training loss.

    Args:
        model: Network called with the tuple ``(MRI_batch, PET_batch)``.
        train_loader: Loader yielding ``(MRI_batch, PET_batch, labels)``;
            its ``dataset`` length is used to average the loss.
        optimiser: Optimiser that updates the model parameters.
        error: Loss callable taking ``(predictions, labels)``.
        scheduler: Learning-rate scheduler, stepped once per batch.
        feature_maps: True when the model returns ``(features, predictions)``,
            in which case the predictions are taken from index 1.
    """
    total_train_loss = 0.0
    # With mean reduction each batch loss is already averaged over the batch,
    # so it has to be weighted by the batch size before dividing by the
    # number of samples; otherwise the printed value is too small by roughly
    # a factor of the batch size.
    mean_reduced = getattr(error, "reduction", "mean") == "mean"

    for MRI_image, PET_image, labels in train_loader:
        MRI_image = MRI_image.to(device)
        PET_image = PET_image.to(device)
        labels = labels.to(device)

        # Clear gradients
        optimiser.zero_grad()

        # Forward propagation
        if feature_maps:
            outputs = model((MRI_image, PET_image))[1]
        else:
            outputs = model((MRI_image, PET_image))

        # Calculate loss
        loss = error(outputs.flatten(), labels)
        total_train_loss += loss.item() * (labels.size(0) if mean_reduced else 1)

        # Calculating gradients
        loss.backward()

        # Update parameters
        optimiser.step()
        # NOTE(review): the scheduler advances every batch, not every epoch,
        # so a schedule with few total iterations finishes within the first
        # epoch - confirm this is intended.
        scheduler.step()

        # Clear GPU Cache
        torch.cuda.empty_cache()
        gc.collect()

    print("Average Training Loss", total_train_loss/len(train_loader.dataset))

def validate_model(model, val_loader, error, feature_maps=False):
    """Evaluate the model on the validation set without gradient tracking.

    Prints the validation accuracy and the mean per-sample loss.

    Args:
        model: Network called with the tuple ``(MRI_batch, PET_batch)``.
        val_loader: Loader yielding ``(MRI_batch, PET_batch, labels)``;
            its ``dataset`` length is used for averaging.
        error: Loss callable taking ``(predictions, labels)``.
        feature_maps: True when the model returns ``(features, predictions)``,
            in which case the predictions are taken from index 1.

    Returns:
        The mean per-sample validation loss as a float.
    """
    correct_predictions_val = 0
    total_val_loss = 0.0
    # With mean reduction each batch loss is already averaged over the batch,
    # so weight it by the batch size before dividing by the sample count.
    mean_reduced = getattr(error, "reduction", "mean") == "mean"

    with torch.no_grad():
        for MRI_image, PET_image, labels in val_loader:
            MRI_image = MRI_image.to(device)
            PET_image = PET_image.to(device)
            labels = labels.to(device)

            # Forward propagation
            if feature_maps:
                pred = model((MRI_image, PET_image))[1]
            else:
                pred = model((MRI_image, PET_image))

            # Calculate loss
            loss = error(pred.flatten(), labels)
            total_val_loss += loss.item() * (labels.size(0) if mean_reduced else 1)

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Calculate accuracy: predictions are rounded to 0 or 1.
            preds = pred.flatten().round()
            correct_predictions_val += torch.sum(preds == labels).item()

    num_samples = len(val_loader.dataset)
    avg_val_loss = total_val_loss/num_samples
    print("Validation Accuracy:", correct_predictions_val/num_samples)
    print("Average Validation Loss:", avg_val_loss)
    return avg_val_loss
        
def evaluate_model(model, test_loader, error, feature_maps=False):
    """Print test accuracy, recall and precision for a trained model.

    Args:
        model: Network called with the tuple ``(MRI_batch, PET_batch)``.
        test_loader: Loader yielding ``(MRI_batch, PET_batch, labels)``.
        error: Unused; kept so the signature mirrors ``validate_model``.
        feature_maps: True when the model returns ``(features, predictions)``,
            in which case the predictions are taken from index 1.
    """
    correct_predictions_test = 0
    preds, labels = [], []

    with torch.no_grad():
        for MRI_image, PET_image, label in test_loader:
            MRI_image = MRI_image.to(device)
            PET_image = PET_image.to(device)
            batch_labels = label.cpu()
            labels.append(batch_labels)

            # Forward propagation
            scores = model((MRI_image, PET_image))
            if feature_maps:
                scores = scores[1]

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Round the scores to hard 0/1 predictions and count the hits.
            batch_preds = scores.flatten().round().cpu()
            preds.append(batch_preds)
            correct_predictions_test += torch.sum(batch_preds == batch_labels).item()

    print("Test Accuracy:", correct_predictions_test/len(test_loader.dataset))
    all_labels = torch.cat(labels)
    all_preds = torch.cat(preds)
    print("Recall:", recall_score(all_labels, all_preds))
    print("Precision:", precision_score(all_labels, all_preds))
    # Raw per-batch predictions and labels, for manual inspection.
    print(preds)
    print(labels)

## Stage 1: Patch Feature Extraction

### Load MRI patch models for each patch location

In [None]:
# Reload weights for rerunning
# One ResNetV2 (depth 10, single input channel, single output) per patch location.
MRI_models = []
for i in range(27):
    patch_model = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0).to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    patch_model.load_state_dict(torch.load(f"./patch_temp/PATCH_TEMP{i}", map_location=device))
    MRI_models.append(patch_model)
    

### Load PET patch Models for each patch location

In [None]:
# Reload weights for rerunning
# One ResNetV2 (depth 10, single input channel, single output) per patch location.
PET_models = []
for i in range(27):
    patch_model = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0).to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    patch_model.load_state_dict(torch.load(f"./patch_temp/PET_PATCH_TEMP{i}", map_location=device))
    PET_models.append(patch_model)
    

## Stage 2: Multimodal Attention Models

The models trained below are at the patch level. A model is trained for each patch location. Feature maps from the trained MRI and PET patch models at the same patch location are combined for classification.

In [None]:
# Define multimodal patch model
# This combines the features from the individual MRI and PET patch models at a patch location

class multimodal_patch_CNN(nn.Module):
    """Fuse an MRI and a PET patch model with a dense head (no attention).

    The first element returned by each of the two patch models is passed
    through dropout, concatenated, and fed to a 200 -> 100 linear layer with
    ReLU and batch norm. A final 100 -> 1 linear layer with a sigmoid
    produces the prediction.
    """

    def __init__(self, patch_models):
        """Store the ``(MRI_model, PET_model)`` pair and build the fusion head."""
        super().__init__()
        self.patch_models = nn.ModuleList(patch_models)
        self.drop = nn.Dropout(p=0.4)

        # Hidden layer over the concatenated modality features.
        self.fc1 = nn.Linear(200, 100)
        self.rel1 = nn.ReLU()
        self.batch_norm1 = nn.BatchNorm1d(100)

        # Output layer.
        self.fc2 = nn.Linear(100, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``(features, output)`` for ``x = (MRI_patch, PET_patch)``.

        ``features`` is the batch-normalised hidden activation and ``output``
        the sigmoid prediction computed from it.
        """
        # Index 0 of each patch model's return value is its feature vector.
        modality_features = [
            self.drop(self.patch_models[k](x[k])[0]) for k in (0, 1)
        ]
        fused = torch.cat(modality_features, dim=1)
        hidden = self.rel1(self.fc1(fused))
        features = self.batch_norm1(hidden)
        return features, self.sigmoid(self.fc2(features))

In [None]:
# Define multimodal patch model with attention-based mechanism
# This combines the features from the individual MRI and PET patch models at a patch location

class multimodal_patch_CNN_attention(nn.Module):
    """Fuse an MRI and a PET patch model with multi-head self-attention.

    The first element returned by each of the two patch models is passed
    through dropout and the two results are stacked into a length-2
    sequence. The sequence attends to itself (4 heads, embedding size 100);
    the attended pair is flattened to 200 values per sample and classified
    by a linear layer with a sigmoid.
    """

    def __init__(self, patch_models):
        """Store the ``(MRI_model, PET_model)`` pair and build the attention head."""
        super().__init__()
        self.patch_models = nn.ModuleList(patch_models)
        self.drop = nn.Dropout(p=0.4)

        self.att = nn.MultiheadAttention(embed_dim=100, num_heads=4, dropout=0.4)
        self.fc1 = nn.Linear(200, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``(features, output)`` for ``x = (MRI_patch, PET_patch)``.

        ``features`` is the flattened attention output (200 values per
        sample) and ``output`` the sigmoid prediction computed from it.
        """
        # Index 0 of each patch model's return value is its feature vector.
        modality_features = [
            self.drop(self.patch_models[k](x[k])[0]) for k in (0, 1)
        ]

        # Stacking puts the modality axis first: (modality, batch, embedding).
        tokens = torch.stack(modality_features)
        attended, _ = self.att(tokens, tokens, tokens)

        # Back to batch-first, then merge both modality embeddings per sample.
        batch_first = attended.permute(1, 0, 2)
        features = batch_first.reshape(batch_first.shape[0], 200)
        return features, self.sigmoid(self.fc1(features))

### Multimodal with Attention

In [None]:
# Set seeds
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)


MRI_PET_patch_models = []

# For each patch location, combine the outputs of the trained patch models
# of both modalities with attention and learn local features
for i in range(27):
    MRI_patch_model = MRI_models[i]
    PET_patch_model = PET_models[i]
    MRI_PET_patch_model = multimodal_patch_CNN_attention((MRI_patch_model,PET_patch_model)).to(device)

    # Freeze MRI and PET patch model layers (we only want the feature maps)
    # NOTE(review): .train() below still switches the frozen backbones into
    # training mode, so any BatchNorm/Dropout layers inside them keep
    # updating or sampling - confirm whether that is intended.
    for param in MRI_patch_model.parameters():
        param.requires_grad = False

    for param in PET_patch_model.parameters():
        param.requires_grad = False

    # Instantiate dataloaders
    MRI_PET_train_dataloaders = patch_train_loaders_MRI_PET[i]
    MRI_PET_val_dataloaders = patch_val_loaders_MRI_PET[i]
    MRI_PET_test_dataloaders = patch_test_loaders_MRI_PET[i]

    best_val_loss = np.inf
    patience = 15
    no_improvement = 0
    checkpoint_path = f"./patch_temp/multimodal_PATCH_TEMP_attention{i}"

    # Binary Cross Entropy Loss
    error = nn.BCELoss()

    # SGD Optimizer
    optimiser = SGD(MRI_PET_patch_model.parameters(), lr=0.0001, momentum=0.99)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)

    for epoch in range(500):
        print(f"----------------------------EPOCH {epoch} PATCH {i}-------------------------------")
        MRI_PET_patch_model.train()
        train_model(MRI_PET_patch_model, MRI_PET_train_dataloaders, optimiser, error, scheduler, True)

        MRI_PET_patch_model.eval()
        val_loss = validate_model(MRI_PET_patch_model, MRI_PET_val_dataloaders, error, True)

        # Save model weights if improvement seen
        # Otherwise stop model training once more than "patience" consecutive
        # epochs have passed without an improvement in validation loss
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(MRI_PET_patch_model.state_dict(), checkpoint_path)
        else:
            no_improvement += 1
            if no_improvement > patience:
                # The tracked quantity is the validation loss, not accuracy.
                print(f"BEST VAL LOSS: {best_val_loss}")
                break

    print("----------------------------TEST RESULTS-------------------------------")
    # Evaluate the best checkpoint; map_location keeps this working when the
    # checkpoint was written on a different device.
    MRI_PET_patch_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    MRI_PET_patch_model.eval()
    evaluate_model(MRI_PET_patch_model, MRI_PET_test_dataloaders, error, True)
    MRI_PET_patch_models.append(MRI_PET_patch_model)

### Multimodal without attention

In [None]:
# Set Seeds
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)


MRI_PET_patch_models = []

# For each patch location, concatenate outputs of trained patch models for each modality and learn local features
for i in range(27):
    MRI_patch_model = MRI_models[i]
    PET_patch_model = PET_models[i]
    MRI_PET_patch_model = multimodal_patch_CNN((MRI_patch_model,PET_patch_model)).to(device)

    # Freeze MRI and PET patch model layers (we only want the feature maps)
    # NOTE(review): .train() below still switches the frozen backbones into
    # training mode, so any BatchNorm/Dropout layers inside them keep
    # updating or sampling - confirm whether that is intended.
    for param in MRI_patch_model.parameters():
        param.requires_grad = False

    for param in PET_patch_model.parameters():
        param.requires_grad = False

    # Instantiate dataloaders
    MRI_PET_train_dataloaders = patch_train_loaders_MRI_PET[i]
    MRI_PET_val_dataloaders = patch_val_loaders_MRI_PET[i]
    MRI_PET_test_dataloaders = patch_test_loaders_MRI_PET[i]

    best_val_loss = np.inf
    patience = 15
    no_improvement = 0
    checkpoint_path = f"./patch_temp/multimodal_PATCH_TEMP{i}"

    # Binary Cross Entropy Loss
    error = nn.BCELoss()

    # SGD Optimizer
    optimiser = SGD(MRI_PET_patch_model.parameters(), lr=0.0001, momentum=0.99)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)

    for epoch in range(500):
        print(f"----------------------------EPOCH {epoch} PATCH {i}-------------------------------")
        MRI_PET_patch_model.train()
        train_model(MRI_PET_patch_model, MRI_PET_train_dataloaders, optimiser, error, scheduler, True)

        MRI_PET_patch_model.eval()
        val_loss = validate_model(MRI_PET_patch_model, MRI_PET_val_dataloaders, error, True)

        # Save model weights if improvement seen
        # Otherwise stop model training once more than "patience" consecutive
        # epochs have passed without an improvement in validation loss
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(MRI_PET_patch_model.state_dict(), checkpoint_path)
        else:
            no_improvement += 1
            if no_improvement > patience:
                # The tracked quantity is the validation loss, not accuracy.
                print(f"BEST VAL LOSS: {best_val_loss}")
                break

    print("----------------------------TEST RESULTS-------------------------------")
    # Evaluate the best checkpoint; map_location keeps this working when the
    # checkpoint was written on a different device.
    MRI_PET_patch_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    MRI_PET_patch_model.eval()
    evaluate_model(MRI_PET_patch_model, MRI_PET_test_dataloaders, error, True)
    MRI_PET_patch_models.append(MRI_PET_patch_model)

## Stage 3: Patch Fusion

In [None]:
# Stacking all 27 patch based datasets to create a single subject level dataset
# We will now have a subject level dataset where each row is a subject and their 27 patch images
# e.g row 1 would be: [patch_1,....,patch_27] for subject 1

def _stack_patch_locations(patch_datasets, tensor_index):
    """Stack one tensor slot of every patch-location dataset along a new dim 1."""
    return torch.stack([dataset.tensors[tensor_index] for dataset in patch_datasets], dim=1)


# Tensor slot 0 of each patch dataset holds the MRI patches.
all_patch_train_MRI = _stack_patch_locations(patch_train_datasets_MRI_PET, 0)
all_patch_val_MRI = _stack_patch_locations(patch_val_datasets_MRI_PET, 0)
all_patch_test_MRI = _stack_patch_locations(patch_test_datasets_MRI_PET, 0)

# Tensor slot 1 of each patch dataset holds the PET patches.
all_patch_train_PET = _stack_patch_locations(patch_train_datasets_MRI_PET, 1)
all_patch_val_PET = _stack_patch_locations(patch_val_datasets_MRI_PET, 1)
all_patch_test_PET = _stack_patch_locations(patch_test_datasets_MRI_PET, 1)

# Subject-level datasets: (all MRI patches, all PET patches, label).
all_patch_train_MRI_PET_dataset = torch.utils.data.TensorDataset(
    all_patch_train_MRI, all_patch_train_PET, train_y)
all_patch_val_MRI_PET_dataset = torch.utils.data.TensorDataset(
    all_patch_val_MRI, all_patch_val_PET, val_y)
all_patch_test_MRI_PET_dataset = torch.utils.data.TensorDataset(
    all_patch_test_MRI, all_patch_test_PET, test_y)

In [None]:
batch_size = 10
# Only the training loader shuffles. Validation metrics do not depend on
# sample order, so the validation loader keeps a fixed order like the
# patch-level validation loaders do.
all_patch_train_dataloader_MRI_PET = torch.utils.data.DataLoader(all_patch_train_MRI_PET_dataset,
                                                                batch_size = batch_size, shuffle = True)
all_patch_val_dataloader_MRI_PET = torch.utils.data.DataLoader(all_patch_val_MRI_PET_dataset,
                                                                batch_size = batch_size, shuffle = False)
all_patch_test_dataloader_MRI_PET = torch.utils.data.DataLoader(all_patch_test_MRI_PET_dataset,
                                                                batch_size = batch_size, shuffle = False)



In [None]:
# Define our final multimodal patch fusion network
class multimodal_CNN(nn.Module):
    """Fuse the feature vectors of all patch-level multimodal models.

    Every patch model is called with its own ``(MRI_patch, PET_patch)`` pair
    and must return ``(features, prediction)``. The feature vectors are
    passed through dropout, concatenated, and classified by one linear layer
    followed by a sigmoid.
    """

    def __init__(self, MRI_PET_patch_models, feature_dim=None):
        """Build the fusion head on top of the given patch models.

        Args:
            MRI_PET_patch_models: Iterable of patch-level multimodal models,
                one per patch location.
            feature_dim: Length of the feature vector each patch model
                returns. When omitted it is inferred per model from the
                input width of that model's last directly-owned
                ``nn.Linear`` layer, falling back to 200 when a model owns
                none. Pass it explicitly for patch models whose features do
                not feed their final linear layer.
        """
        super().__init__()
        self.MRI_PET_patch_models = nn.ModuleList(MRI_PET_patch_models)
        self.drop = nn.Dropout(p=0.4)

        # Size the fusion layer from the patch models instead of a fixed
        # 5400, which only matches 27 models with 200-dimensional features
        # and fails in forward for models that return 100-dimensional ones.
        if feature_dim is None:
            in_features = sum(
                self._infer_feature_dim(patch_model)
                for patch_model in self.MRI_PET_patch_models
            )
        else:
            in_features = feature_dim * len(self.MRI_PET_patch_models)
        self.fc3 = nn.Linear(in_features, 1)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _infer_feature_dim(patch_model):
        """Return the input width of the model's last direct nn.Linear child."""
        linear_children = [
            child for child in patch_model.children() if isinstance(child, nn.Linear)
        ]
        return linear_children[-1].in_features if linear_children else 200

    def forward(self, x):
        """Return the sigmoid prediction for ``x = (MRI_patches, PET_patches)``.

        Dimension 1 of both tensors indexes the patch location.
        """
        mri_patches, pet_patches = x[0], x[1]
        patch_outputs = []
        for i, patch_model in enumerate(self.MRI_PET_patch_models):
            # Only the feature vector is used; the patch prediction is dropped.
            patch_output, _ = patch_model((mri_patches[:, i], pet_patches[:, i]))
            patch_outputs.append(self.drop(patch_output))
        fused = torch.cat(patch_outputs, dim=1)
        return self.sigmoid(self.fc3(fused))

### Final Multimodal Patch Model with attention

In [None]:
# Reload attention weights for rerunning
MRI_PET_patch_models = []
for i in range(27):
    MRI_PET_patch_model = multimodal_patch_CNN_attention((MRI_models[i],PET_models[i])).to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    MRI_PET_patch_model.load_state_dict(
        torch.load(f"./patch_temp/multimodal_PATCH_TEMP_attention{i}", map_location=device))
    MRI_PET_patch_models.append(MRI_PET_patch_model)

MRI_patch_model = multimodal_CNN(MRI_PET_patch_models).to(device)

# Freeze layers of the patch models (we only want the feature maps)
for patch_model in MRI_patch_model.MRI_PET_patch_models:
    for param in patch_model.parameters():
        param.requires_grad = False

# Set seeds
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)

# Binary Cross Entropy Loss
error = nn.BCELoss()

# SGD Optimizer
optimiser = SGD(MRI_patch_model.parameters(), lr=0.0001, momentum=0.9)
# Use the fully qualified scheduler path: a bare `lr_scheduler` name is only
# defined if the star import from torch.optim happens to export it.
scheduler = torch.optim.lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)


best_val_loss = np.inf
patience = 20
no_improvement = 0
checkpoint_path = "./trained_models_2/MULTIMODAL_MODEL_MODAL_ATTENTION"

for epoch in range(1000):
    print(f"----------------------------EPOCH {epoch}-------------------------------")
    MRI_patch_model.train()
    train_model(MRI_patch_model, all_patch_train_dataloader_MRI_PET, optimiser, error, scheduler)

    MRI_patch_model.eval()
    avg_val_loss = validate_model(MRI_patch_model, all_patch_val_dataloader_MRI_PET, error)

    # Save model weights if improvement seen
    # Otherwise stop model training once more than "patience" consecutive
    # epochs have passed without an improvement in validation loss
    if avg_val_loss <= best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement = 0
        torch.save(MRI_patch_model.state_dict(), checkpoint_path)
    else:
        no_improvement += 1
        if no_improvement > patience:
            print(f"BEST VAL LOSS: {best_val_loss}")
            break

print("----------------------------TEST RESULTS-------------------------------")
MRI_patch_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
MRI_patch_model.eval()
evaluate_model(MRI_patch_model, all_patch_test_dataloader_MRI_PET, error)

### Final Multimodal Patch Model with no attention

In [None]:
# Reload no attention weights for rerunning
MRI_PET_patch_models = []
for i in range(27):
    MRI_PET_patch_model = multimodal_patch_CNN((MRI_models[i],PET_models[i])).to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    MRI_PET_patch_model.load_state_dict(
        torch.load(f"./patch_temp/multimodal_PATCH_TEMP{i}", map_location=device))
    MRI_PET_patch_models.append(MRI_PET_patch_model)

MRI_patch_model = multimodal_CNN(MRI_PET_patch_models).to(device)

# multimodal_patch_CNN returns 100-dimensional features (the output of its
# BatchNorm1d(100)), so the fusion layer needs 27 * 100 inputs. Rebuild the
# layer if the fusion model was created with any other input width, which
# would otherwise fail with a shape mismatch on the first forward pass.
fusion_inputs = 27 * 100
if MRI_patch_model.fc3.in_features != fusion_inputs:
    MRI_patch_model.fc3 = nn.Linear(fusion_inputs, 1).to(device)

# Freeze layers of the patch models (we only want the feature maps)
for patch_model in MRI_patch_model.MRI_PET_patch_models:
    for param in patch_model.parameters():
        param.requires_grad = False

# Set seeds
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)

# Binary Cross Entropy Loss
error = nn.BCELoss()

# SGD Optimizer
optimiser = SGD(MRI_patch_model.parameters(), lr=0.0001, momentum=0.9)
# Use the fully qualified scheduler path: a bare `lr_scheduler` name is only
# defined if the star import from torch.optim happens to export it.
scheduler = torch.optim.lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)


best_val_loss = np.inf
patience = 20
no_improvement = 0
checkpoint_path = "./trained_models_2/MULTIMODAL_MODEL_4"

for epoch in range(1000):
    print(f"----------------------------EPOCH {epoch}-------------------------------")
    MRI_patch_model.train()
    train_model(MRI_patch_model, all_patch_train_dataloader_MRI_PET, optimiser, error, scheduler)

    MRI_patch_model.eval()
    avg_val_loss = validate_model(MRI_patch_model, all_patch_val_dataloader_MRI_PET, error)

    # Save model weights if improvement seen
    # Otherwise stop model training once more than "patience" consecutive
    # epochs have passed without an improvement in validation loss
    if avg_val_loss <= best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement = 0
        torch.save(MRI_patch_model.state_dict(), checkpoint_path)
    else:
        no_improvement += 1
        if no_improvement > patience:
            print(f"BEST VAL LOSS: {best_val_loss}")
            break

print("----------------------------TEST RESULTS-------------------------------")
MRI_patch_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
MRI_patch_model.eval()
evaluate_model(MRI_patch_model, all_patch_test_dataloader_MRI_PET, error)