# Multimodal Patch Based networks

The following model expands on the a patch based architectures by combining feature maps of both MRI and PET modalities using attention based mechanisms. As a part of the ablation study, we also train a model without attention based mechanisms. There are three classifcation stages for the final multimodal attention-based model: 
- <b>Patch Feature Extraction:</b> 
    - 27 ResNet models are trained on patches from each patch location to extract local features.
- <b>Multimodal Attention:</b>
    - For each patch location, we train a model which combines the PET and MRI feature maps at that location using the corresponding patch models.  Multihead attention is used to capture the relationships between both modalities.
- <b>Patch fusion:</b>
    - Feature maps of the 27 trained attention-patch models are concatenated and used as inputs to a final model for global classification

## Importing Libraries

In [1]:
# importing the libraries
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
import os
import gc
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import precision_score, recall_score, accuracy_score
from sklearn.preprocessing import StandardScaler

 
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import *
import nibabel as nb
import torchio as tio

from models import (cnn, C3DNet, resnet, ResNetV2, ResNeXt, ResNeXtV2, WideResNet, PreActResNet,
        EfficientNet, DenseNet, ShuffleNet, ShuffleNetV2, SqueezeNet, MobileNet, MobileNetV2)




## Read in images

In [2]:
# Subject IDs who are progressive normal cognition
PNC = pd.read_pickle('PNC.pkl')

# Subject IDs who are stable normal cognition
SNC = pd.read_pickle('SNC.pkl')

In [3]:
# Create datasets

def read_image_data(input_path):
    X_train = []
    y_train = []
    ids = []
    for filename in sorted(os.listdir(input_path)):
        file = os.path.join(input_path, filename)
        img_file = nb.load(file)
        img = img_file.get_fdata()
        X_train.append(img)
        ids.append(filename)
        # Progressive normal cognition target class 1
        if filename[0:8] in np.array(PNC):
            y_train.append(1)
        # Stable normal cognition class 0
        else:
            y_train.append(0)

    X_train = np.array(X_train)
    y_train = np.array(y_train)

    # Reshape X_train to include channel dimension
    X_train = X_train.reshape(X_train.shape + (1,))
    return X_train, y_train, ids

X_MRI, y_MRI, ids_MRI = read_image_data("E:/Work/Processed_MRI/2.MNI_Registration")
X_PET, y_PET, ids_PET = read_image_data("E:/Work/Processed_PIB/4.MNI_Registered")


## Creating training and test sets for images

In [4]:
# 0.6 - 0.2 - 0.2 train-val-test split
X_train_MRI, X_test_MRI, X_train_PET, X_test_PET, ids_train, ids_test, y_train, y_test = train_test_split(X_MRI, 
                                                                                     X_PET, ids_MRI, 
                                                                                     y_MRI, test_size=0.2, 
                                                                                     random_state=101, stratify=y_MRI)

X_train_MRI, X_val_MRI, X_train_PET, X_val_PET, ids_train, ids_val, y_train, y_val = train_test_split(X_train_MRI, X_train_PET, 
                                                                                      ids_train, y_train, 
                                                                                      test_size=0.25, random_state=101,
                                                                                      stratify=y_train)

In [5]:
# Convert datasets to tensor format, with channel first
train_x_MRI = torch.from_numpy(X_train_MRI).float().permute(0,4,1,2,3)
train_x_PET = torch.from_numpy(X_train_PET).float().permute(0,4,1,2,3)
train_y = torch.from_numpy(y_train).float()

val_x_MRI = torch.from_numpy(X_val_MRI).float().permute(0,4,1,2,3)
val_x_PET = torch.from_numpy(X_val_PET).float().permute(0,4,1,2,3)
val_y = torch.from_numpy(y_val).float()

test_x_MRI = torch.from_numpy(X_test_MRI).float().permute(0,4,1,2,3)
test_x_PET = torch.from_numpy(X_test_PET).float().permute(0,4,1,2,3)
test_y = torch.from_numpy(y_test).float()

## Perform data augmentation to increase training set size

In [None]:
def perform_augmentation(dataset, seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Define transformations
    training_transform = tio.Compose([
        tio.RandomAffine(),
        tio.RandomElasticDeformation(),
        tio.RandomFlip()
    ])
    augmented_dataset = torch.clone(dataset) 
    for i in range(len(augmented_dataset)):
        augmented_dataset[i] = training_transform(augmented_dataset[i])
    return augmented_dataset

orig_train_x_MRI = torch.clone(train_x_MRI)
orig_train_x_PET = torch.clone(train_x_PET)
orig_train_y = torch.clone(train_y)
for seed in [1,101,42]:
    # Apply transformations and create augmented training set
    augmented_train_MRI = perform_augmentation(orig_train_x_MRI, seed)
    augmented_train_PET = perform_augmentation(orig_train_x_PET, seed)
    
    
    # Concatenate training and augmented training datasets
    train_x_MRI = torch.cat((train_x_MRI, augmented_train_MRI), 0)
    train_x_PET = torch.cat((train_x_PET, augmented_train_PET), 0)
    train_y = torch.cat((train_y, orig_train_y), 0)

# Removing blank space around brain to make dimensions even so we can divide image into uniform patches
train_x_MRI = train_x_MRI[:, :, 0:88, 0:108, 0:88]
train_x_PET = train_x_PET[:, :, 0:88, 0:108, 0:88]   

In [6]:
# Load augmented dataset (for reruns)
train_x_MRI = torch.load("train_x_MRI.pkl")
train_x_PET = torch.load("train_x_PET.pkl")
train_y = torch.load("train_y.pkl")

## Divide images into patches

In [7]:
# For each sample, divide images into 27 uniform 3x3x3 patches of size 44x54x44 with 50% overlap
# Then create 27 training datasets for each patch location 
# e.g a data set for patches of each subject in the top left corner, ..., a dataset for patches of each subject in the middle

def create_MRI_PET_patches(images, dim, labels):
    # Create datasets for each patch location
    datasets_MRI = [[] for _ in range(27)]
    datasets_PET = [[] for _ in range(27)]
    datasets_MRI_PET = [[] for _ in range(27)]
    # Find the starting locations of each patch
    starting_points = []
    for height_stride in range(3):
        for width_stride in range(3):
            for depth_stride in range(3):
                    start = (0 + height_stride*dim[1]//4, 0 + width_stride*dim[2]//4, 0 + depth_stride*dim[3]//4)
                    starting_points.append(start)

    # For each image
    for i in range(len(images[0])):
        # Create patch from every starting point
        for j in range(len(starting_points)):
            start_pt = starting_points[j]
            patch_MRI = images[0][i][:, start_pt[0]:start_pt[0] + dim[1]//2, 
                                      start_pt[1]:start_pt[1] + dim[2]//2, 
                                      start_pt[2]:start_pt[2] + dim[3]//2]
            
            patch_PET = images[1][i][:, start_pt[0]:start_pt[0] + dim[1]//2, 
                                      start_pt[1]:start_pt[1] + dim[2]//2, 
                                      start_pt[2]:start_pt[2] + dim[3]//2]
            datasets_MRI[j].append(patch_MRI)
            datasets_PET[j].append(patch_PET)
           
            
    # For each patch location, stack patches from every sample samples into a tensor as a Tensor Dataset
    for i in range(27):
        datasets_MRI[i] = torch.stack(datasets_MRI[i])
        datasets_PET[i] = torch.stack(datasets_PET[i]) 
        datasets_MRI_PET[i] = torch.utils.data.TensorDataset(datasets_MRI[i], datasets_PET[i], labels)
    return datasets_MRI_PET


In [8]:
# Create datasets for each patch location

dim = train_x_MRI[0].shape
patch_train_datasets_MRI_PET = create_MRI_PET_patches((train_x_MRI, train_x_PET), dim, train_y)
patch_val_datasets_MRI_PET = create_MRI_PET_patches((val_x_MRI, val_x_PET), dim, val_y)
patch_test_datasets_MRI_PET = create_MRI_PET_patches((test_x_MRI, test_x_PET), dim, test_y)


# Create data loaders for each patch location dataset
patch_train_loaders_MRI_PET = []
patch_val_loaders_MRI_PET = []
patch_test_loaders_MRI_PET = []

In [9]:
batch_size = 16
for i in range(27):
    patch_train_loaders_MRI_PET.append(torch.utils.data.DataLoader(patch_train_datasets_MRI_PET[i], 
                                                                  batch_size = batch_size, shuffle = True))
    patch_val_loaders_MRI_PET.append(torch.utils.data.DataLoader(patch_val_datasets_MRI_PET[i], 
                                                                  batch_size = batch_size, shuffle = False))
    patch_test_loaders_MRI_PET.append(torch.utils.data.DataLoader(patch_test_datasets_MRI_PET[i], 
                                                                  batch_size = batch_size, shuffle = False))

In [10]:
# Define device being used for model training
if torch.cuda.is_available(): 
    dev = "cuda:0" 
else: 
    dev = "cpu" 
device = torch.device(dev) 

## Model training

In [11]:
def train_model(model, train_loader, optimiser, error, scheduler, feature_maps=False):
    total_train_loss = 0
    for MRI_image, PET_image, labels in train_loader:
        MRI_image = MRI_image.to(device)
        PET_image = PET_image.to(device)
        labels = labels.to(device)
        # Clear gradients
        optimiser.zero_grad()
        
        # Forward propagation
        # feature_map indicates whether the feature map prior final layer is also included in the outputs of the model
        if feature_maps:
            outputs = model((MRI_image, PET_image))[1]
        else:
            outputs = model((MRI_image, PET_image))
        
        # Calculate loss
        loss = error(outputs.flatten(), labels)
        total_train_loss += loss.item()
        
        # Calculating gradients
        loss.backward()

        # Update parameters
        optimiser.step()
        scheduler.step()
        
        # Clear GPU Cache
        torch.cuda.empty_cache()
        gc.collect()
        
        # Calculate accuracy
        preds = outputs.flatten().round()
    print("Average Training Loss", total_train_loss/len(train_loader.dataset))

def validate_model(model, val_loader, error, feature_maps=False):
    correct_predictions_val = 0
    total_val_loss = 0
    with torch.no_grad():
        for MRI_image, PET_image, labels in val_loader:
            MRI_image = MRI_image.to(device)
            PET_image = PET_image.to(device)
            labels = labels.to(device)

            # Forward propagation
            if feature_maps:
                pred = model((MRI_image, PET_image))[1]
            else:
                pred = model((MRI_image, PET_image))
            
            # Calculate loss
            loss = error(pred.flatten(), labels)
            total_val_loss += loss.item()

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Calculate accuracy
            preds = pred.flatten().round()
            #print(pred)
            correct_predictions_val += torch.sum(preds == labels).item()
        print("Validation Accuracy:", correct_predictions_val/len(val_loader.dataset))
        print("Average Validation Loss:", total_val_loss/len(val_loader.dataset))
    return total_val_loss/len(val_loader.dataset)
        
def evaluate_model(model, test_loader, error, feature_maps=False):
    correct_predictions_test = 0
    preds = []
    labels = []
    with torch.no_grad():
        for MRI_image, PET_image, label in test_loader:
            MRI_image = MRI_image.to(device)
            PET_image = PET_image.to(device)
            
            labels.append(label.cpu())
            
            # Forward propagation
            if feature_maps:
                pred = model((MRI_image, PET_image))[1]
            else:
                pred = model((MRI_image, PET_image))

            # Clear GPU Cache
            torch.cuda.empty_cache()
            gc.collect()

            # Calculate accuracy
            pred = pred.flatten().round()
            preds.append(pred.cpu())
            correct_predictions_test += torch.sum(pred.cpu() == label.cpu()).item()
    print("Test Accuracy:", correct_predictions_test/len(test_loader.dataset))    
    print("Recall:", recall_score(torch.cat(labels), torch.cat(preds)))
    print("Precision:", precision_score(torch.cat(labels), torch.cat(preds)))
    print(preds)
    print(labels)

## Stage 1: Patch Feature Extraction

### Load MRI patch models for each patch location

In [12]:
# Reload weights for rerunning
MRI_models = []
for i in range(27):
    patch_model = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0).to(device)
    patch_model.load_state_dict(torch.load(f"./patch_temp/PATCH_TEMP{i}"))
    MRI_models.append(patch_model)
    

### Load PET patch Models for each patch location

In [13]:
# Reload weights for rerunning
PET_models = []
for i in range(27):
    patch_model = ResNetV2.generate_model(
        model_depth=10,
        n_classes=1,
        n_input_channels=1,
        shortcut_type='B',
        conv1_t_size=7,
        conv1_t_stride=2,
        no_max_pool=False,
        widen_factor=1.0).to(device)
    patch_model.load_state_dict(torch.load(f"./patch_temp/PET_PATCH_TEMP{i}"))
    PET_models.append(patch_model)
    

## Stage 2 :Multimodal Attention Models

The models trained below are at the patch level. A model is trained for each patch location. Feature maps from the trained MRI and PET patch models at the same patch location are combined for classification.

In [14]:
# Define multimodal patch model
# This combines the features from the individual MRI and PET patch models at a patch location

class multimodal_patch_CNN(nn.Module):
    def __init__(self, patch_models):
        super().__init__()
        self.patch_models = nn.ModuleList(patch_models)
        self.drop= nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(200, 100) 
        self.rel1 = nn.ReLU()
        self.batch_norm1 = nn.BatchNorm1d(100)
        
        self.fc2 = nn.Linear(100, 50) 
        self.rel2 = nn.ELU()
        self.batch_norm2 = nn.BatchNorm1d(50)
        
        self.fc3 = nn.Linear(100, 1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        patch_outputs = []
        for i in range(2):
            patch_output = self.patch_models[i](x[i])[0]  
            patch_outputs.append(self.drop(patch_output))
        output = torch.cat(patch_outputs, dim=1)
        output = self.rel1(self.fc1(output))
        features = self.batch_norm1(output)
        output = self.sigmoid(self.fc3(features))
        return features, output

In [15]:
# Define multimodal patch model wtih Attention based mechanism
# This combines the features from the individual MRI and PET patch models at a patch location

class multimodal_patch_CNN_attention(nn.Module):
    def __init__(self, patch_models):
        super().__init__()
        self.patch_models = nn.ModuleList(patch_models)
        self.drop= nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(200, 100) 
        self.rel1 = nn.ReLU()
        self.batch_norm1 = nn.BatchNorm1d(100)
        
        self.fc2 = nn.Linear(100, 50) 
        self.rel2 = nn.ELU()
        self.batch_norm2 = nn.BatchNorm1d(50)
        self.att = nn.MultiheadAttention(embed_dim=100, num_heads=4,dropout=0.4)
        self.fc3 = nn.Linear(200, 1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        patch_outputs = []
        #attention_patch_outputs = []
        for i in range(2):
            patch_output = self.patch_models[i](x[i])[0]  
            patch_outputs.append(self.drop(patch_output))
    
        x = torch.stack(patch_outputs)
        x, _ = self.att(x,x,x)
        x = x.permute(1,0,2)
        features = x.reshape(x.shape[0],200)
        output = self.sigmoid(self.fc3(features))
        return features, output

### Multimodal with Attention

In [25]:
# Set seeds
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)


MRI_PET_patch_models = []

# For each patch location, concatenate outputs of trained patch models for each modality and learn local features
for i in range(27):
    MRI_patch_model = MRI_models[i]
    PET_patch_model = PET_models[i]
    MRI_PET_patch_model = multimodal_patch_CNN_attention((MRI_patch_model,PET_patch_model)).to(device)
    
    # Freeze MRI and PET patch model layers (we only want the feature maps)
    for name, param in MRI_patch_model.named_parameters():
            param.requires_grad = False 
            
    for name, param in PET_patch_model.named_parameters():
            param.requires_grad = False 
    
    
    # Instantiate dataloaders
    MRI_PET_train_dataloaders = patch_train_loaders_MRI_PET[i]
    MRI_PET_val_dataloaders = patch_val_loaders_MRI_PET[i]
    MRI_PET_test_dataloaders = patch_test_loaders_MRI_PET[i]
    
    best_val_loss = np.inf
    patience = 15
    no_improvement = 0
    
    # Binary Cross Entropy Loss
    error = nn.BCELoss()

    # SGD Optimizer
    optimiser = SGD(MRI_PET_patch_model.parameters(), lr=0.0001, momentum=0.99)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)
    
    for epoch in range(500):
        print(f"----------------------------EPOCH {epoch} PATCH {i}-------------------------------")
        MRI_PET_patch_model.train()
        train_model(MRI_PET_patch_model, MRI_PET_train_dataloaders, optimiser, error, scheduler, True)

        MRI_PET_patch_model.eval()
        val_loss = validate_model(MRI_PET_patch_model, MRI_PET_val_dataloaders, error, True)
        
        # Save model weights if improvement seen
        # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(MRI_PET_patch_model.state_dict(), f"./patch_temp/multimodal_PATCH_TEMP_attention{i}")
        else:
            no_improvement += 1
            if no_improvement <= patience:
                continue
            else:
                print(f"BEST VAL ACC: {best_val_loss}")
                break
         
    print("----------------------------TEST RESULTS-------------------------------")
    MRI_PET_patch_model.load_state_dict(torch.load(f"./patch_temp/multimodal_PATCH_TEMP_attention{i}"))
    MRI_PET_patch_model.eval()
    evaluate_model(MRI_PET_patch_model, MRI_PET_test_dataloaders, error, True)
    MRI_PET_patch_models.append(MRI_PET_patch_model)

----------------------------EPOCH 0 PATCH 0-------------------------------
Average Training Loss 0.046272377620954985
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.05150613116055
----------------------------EPOCH 1 PATCH 0-------------------------------
Average Training Loss 0.045924445400472545
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.05143289013606746
----------------------------EPOCH 2 PATCH 0-------------------------------
Average Training Loss 0.04530033469200134
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.05130598893979701
----------------------------EPOCH 3 PATCH 0-------------------------------
Average Training Loss 0.04505757120300512
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.051160318095509597
----------------------------EPOCH 4 PATCH 0-------------------------------
Average Training Loss 0.04461041935643212
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.05099529

Average Training Loss 0.023580997754804423
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04565386365099651
----------------------------EPOCH 41 PATCH 0-------------------------------
Average Training Loss 0.023899576153422964
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.045269384616758765
----------------------------EPOCH 42 PATCH 0-------------------------------
Average Training Loss 0.0231407105189855
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04524412387754859
----------------------------EPOCH 43 PATCH 0-------------------------------
Average Training Loss 0.022323761257480403
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.045381596902521644
----------------------------EPOCH 44 PATCH 0-------------------------------
Average Training Loss 0.021055278536237655
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04510545875968003
----------------------------EPOCH 45 PATCH 0------------

Average Training Loss 0.042039020994647604
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05010557174682617
----------------------------EPOCH 13 PATCH 1-------------------------------
Average Training Loss 0.0422254916830141
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05005763507470852
----------------------------EPOCH 14 PATCH 1-------------------------------
Average Training Loss 0.04200727297145812
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05001290978454962
----------------------------EPOCH 15 PATCH 1-------------------------------
Average Training Loss 0.041623757755170104
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.04995088897100309
----------------------------EPOCH 16 PATCH 1-------------------------------
Average Training Loss 0.04151722361318401
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.049904116770116295
----------------------------EPOCH 17 PATCH 1---------------

Average Training Loss 0.03497495221309974
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04801954874178258
----------------------------EPOCH 54 PATCH 1-------------------------------
Average Training Loss 0.0344663572604539
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04794392498528085
----------------------------EPOCH 55 PATCH 1-------------------------------
Average Training Loss 0.0346752580438481
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04795718338431382
----------------------------EPOCH 56 PATCH 1-------------------------------
Average Training Loss 0.034316205404332424
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.047820220633250914
----------------------------EPOCH 57 PATCH 1-------------------------------
Average Training Loss 0.034031741321086884
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04781971035934076
----------------------------EPOCH 58 PATCH 1----------------

Average Training Loss 0.02666524913711626
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04743323675016078
----------------------------EPOCH 95 PATCH 1-------------------------------
Average Training Loss 0.02692718345855103
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04730354576576047
----------------------------EPOCH 96 PATCH 1-------------------------------
Average Training Loss 0.026924150644755753
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.047276601558778344
----------------------------EPOCH 97 PATCH 1-------------------------------
Average Training Loss 0.025962535108699173
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04719491266622776
----------------------------EPOCH 98 PATCH 1-------------------------------
Average Training Loss 0.026449195552067678
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04721089252611486
----------------------------EPOCH 99 PATCH 1-------------

Average Training Loss 0.04219153499016996
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.05011696350283739
----------------------------EPOCH 23 PATCH 2-------------------------------
Average Training Loss 0.04208901622256295
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.050088817026556993
----------------------------EPOCH 24 PATCH 2-------------------------------
Average Training Loss 0.042024877716283326
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.0500786493464214
----------------------------EPOCH 25 PATCH 2-------------------------------
Average Training Loss 0.04203059524297714
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.05003261420784927
----------------------------EPOCH 26 PATCH 2-------------------------------
Average Training Loss 0.041962169110774994
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.05001690184197775
----------------------------EPOCH 27 PATCH 2---------------

Average Training Loss 0.03783785466287957
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.048874080181121826
----------------------------EPOCH 64 PATCH 2-------------------------------
Average Training Loss 0.03797125474351351
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.04883870700510537
----------------------------EPOCH 65 PATCH 2-------------------------------
Average Training Loss 0.038022704178192576
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04878219453299918
----------------------------EPOCH 66 PATCH 2-------------------------------
Average Training Loss 0.037820009423083945
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.04875837157412273
----------------------------EPOCH 67 PATCH 2-------------------------------
Average Training Loss 0.03721651557039042
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04874567433101375
----------------------------EPOCH 68 PATCH 2--------------

Average Training Loss 0.03231924063846713
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04785111328450645
----------------------------EPOCH 105 PATCH 2-------------------------------
Average Training Loss 0.03333434688507533
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04793243437278562
----------------------------EPOCH 106 PATCH 2-------------------------------
Average Training Loss 0.03293213211610669
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04782453833556757
----------------------------EPOCH 107 PATCH 2-------------------------------
Average Training Loss 0.031734994994323765
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.047883355036014465
----------------------------EPOCH 108 PATCH 2-------------------------------
Average Training Loss 0.03181626562212334
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04779956835072215
----------------------------EPOCH 109 PATCH 2----------

Average Training Loss 0.04260918913317508
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05010629281765077
----------------------------EPOCH 8 PATCH 3-------------------------------
Average Training Loss 0.04225253802342493
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.050042182933993457
----------------------------EPOCH 9 PATCH 3-------------------------------
Average Training Loss 0.042256146425106486
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.049976857697091454
----------------------------EPOCH 10 PATCH 3-------------------------------
Average Training Loss 0.04176687571357508
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04992533747742816
----------------------------EPOCH 11 PATCH 3-------------------------------
Average Training Loss 0.041902273526934326
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.049850953788292116
----------------------------EPOCH 12 PATCH 3--------------

Average Training Loss 0.032584828981122034
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04740423109473252
----------------------------EPOCH 49 PATCH 3-------------------------------
Average Training Loss 0.0330410718795706
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.047373669903452806
----------------------------EPOCH 50 PATCH 3-------------------------------
Average Training Loss 0.03227991916117121
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04721771653105573
----------------------------EPOCH 51 PATCH 3-------------------------------
Average Training Loss 0.03296936010239554
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04724880980282295
----------------------------EPOCH 52 PATCH 3-------------------------------
Average Training Loss 0.031426270965669974
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.0470162949910978
----------------------------EPOCH 53 PATCH 3----------------

Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04658866364781449
----------------------------EPOCH 89 PATCH 3-------------------------------
Average Training Loss 0.020731430989308436
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04689560430805858
----------------------------EPOCH 90 PATCH 3-------------------------------
Average Training Loss 0.02081551073027439
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04696945664359302
----------------------------EPOCH 91 PATCH 3-------------------------------
Average Training Loss 0.019534002623108566
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.046901051591082314
----------------------------EPOCH 92 PATCH 3-------------------------------
Average Training Loss 0.02005069941037991
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04695245332834197
BEST VAL ACC: 0.04601569728153508
----------------------------TEST RESULTS--------------------------

Average Training Loss 0.03961633255735773
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.049151660465612645
----------------------------EPOCH 34 PATCH 4-------------------------------
Average Training Loss 0.039380261766128855
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04910384736410001
----------------------------EPOCH 35 PATCH 4-------------------------------
Average Training Loss 0.039012487794532154
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04905460229734095
----------------------------EPOCH 36 PATCH 4-------------------------------
Average Training Loss 0.039139368128581126
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04899209447023345
----------------------------EPOCH 37 PATCH 4-------------------------------
Average Training Loss 0.038949234319514914
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04894056698171104
----------------------------EPOCH 38 PATCH 4------------

Average Training Loss 0.032304886301032835
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04742742747795291
----------------------------EPOCH 75 PATCH 4-------------------------------
Average Training Loss 0.03305225575067958
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04742925196159177
----------------------------EPOCH 76 PATCH 4-------------------------------
Average Training Loss 0.03228786703748781
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04737157211071107
----------------------------EPOCH 77 PATCH 4-------------------------------
Average Training Loss 0.032343760010648946
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.047390520572662354
----------------------------EPOCH 78 PATCH 4-------------------------------
Average Training Loss 0.03158039581335959
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.047456265949621435
----------------------------EPOCH 79 PATCH 4-------------

Average Training Loss 0.04358294252000871
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05027492162657947
----------------------------EPOCH 6 PATCH 5-------------------------------
Average Training Loss 0.04333159046583488
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05021955327289861
----------------------------EPOCH 7 PATCH 5-------------------------------
Average Training Loss 0.043216640465572234
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05016024083625979
----------------------------EPOCH 8 PATCH 5-------------------------------
Average Training Loss 0.043008311239422344
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.050096289413731274
----------------------------EPOCH 9 PATCH 5-------------------------------
Average Training Loss 0.04290726255686557
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.0500388610653761
----------------------------EPOCH 10 PATCH 5-------------------

Average Training Loss 0.03554534631185844
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.0474960149788275
----------------------------EPOCH 47 PATCH 5-------------------------------
Average Training Loss 0.035430511795595046
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.047448049231273375
----------------------------EPOCH 48 PATCH 5-------------------------------
Average Training Loss 0.03576481244603141
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04733268080688104
----------------------------EPOCH 49 PATCH 5-------------------------------
Average Training Loss 0.03524247105004358
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04732084565046357
----------------------------EPOCH 50 PATCH 5-------------------------------
Average Training Loss 0.034886438827045625
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04715069183489171
----------------------------EPOCH 51 PATCH 5---------------

Average Training Loss 0.027418679695148936
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.045092222167224416
----------------------------EPOCH 88 PATCH 5-------------------------------
Average Training Loss 0.026394711410413024
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04502083470181721
----------------------------EPOCH 89 PATCH 5-------------------------------
Average Training Loss 0.026011625152142323
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04497869857927648
----------------------------EPOCH 90 PATCH 5-------------------------------
Average Training Loss 0.025840218682758143
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.044977859752934155
----------------------------EPOCH 91 PATCH 5-------------------------------
Average Training Loss 0.024720212536268545
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04495374749346477
----------------------------EPOCH 92 PATCH 5----------

Average Training Loss 0.04220341805551873
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.05001187615278291
----------------------------EPOCH 13 PATCH 6-------------------------------
Average Training Loss 0.04171355679386952
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.049900956270171375
----------------------------EPOCH 14 PATCH 6-------------------------------
Average Training Loss 0.041440868353257415
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04979406915059904
----------------------------EPOCH 15 PATCH 6-------------------------------
Average Training Loss 0.04131221673527702
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04968921004272089
----------------------------EPOCH 16 PATCH 6-------------------------------
Average Training Loss 0.04087454324862996
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.049574865073692506
----------------------------EPOCH 17 PATCH 6--------------

Average Training Loss 0.025360208798627385
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04487356325475181
----------------------------EPOCH 54 PATCH 6-------------------------------
Average Training Loss 0.024995299453129532
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04482850795838891
----------------------------EPOCH 55 PATCH 6-------------------------------
Average Training Loss 0.024381461324261836
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04464529054920848
----------------------------EPOCH 56 PATCH 6-------------------------------
Average Training Loss 0.023367880064933025
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04459789322643745
----------------------------EPOCH 57 PATCH 6-------------------------------
Average Training Loss 0.023533952285031804
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04446348620624077
----------------------------EPOCH 58 PATCH 6------------

Average Training Loss 0.04320110795927829
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05042618076975753
----------------------------EPOCH 1 PATCH 7-------------------------------
Average Training Loss 0.04312640428543091
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05038301683053738
----------------------------EPOCH 2 PATCH 7-------------------------------
Average Training Loss 0.04303534915212725
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.050343805696906115
----------------------------EPOCH 3 PATCH 7-------------------------------
Average Training Loss 0.042932733771253805
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05026832877135858
----------------------------EPOCH 4 PATCH 7-------------------------------
Average Training Loss 0.04289661274581659
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.050207920190764636
----------------------------EPOCH 5 PATCH 7-------------------

Average Training Loss 0.03539006314316734
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.047545793579845896
----------------------------EPOCH 42 PATCH 7-------------------------------
Average Training Loss 0.035272601808680866
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.0473913056094472
----------------------------EPOCH 43 PATCH 7-------------------------------
Average Training Loss 0.03470595290914911
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04731640001622642
----------------------------EPOCH 44 PATCH 7-------------------------------
Average Training Loss 0.0346413793378189
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.047275976436894115
----------------------------EPOCH 45 PATCH 7-------------------------------
Average Training Loss 0.0344365689231724
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04714730018522681
----------------------------EPOCH 46 PATCH 7-----------------

Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04506984716508447
----------------------------EPOCH 82 PATCH 7-------------------------------
Average Training Loss 0.026326240208305297
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.0450898074522251
----------------------------EPOCH 83 PATCH 7-------------------------------
Average Training Loss 0.025025057927018306
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.044975910244918454
----------------------------EPOCH 84 PATCH 7-------------------------------
Average Training Loss 0.024090055742713272
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04484556579008335
----------------------------EPOCH 85 PATCH 7-------------------------------
Average Training Loss 0.024187710930089482
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04490750737306548
----------------------------EPOCH 86 PATCH 7-------------------------------
Average Training Loss 0

Average Training Loss 0.04176171294978408
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04956648844044383
----------------------------EPOCH 15 PATCH 8-------------------------------
Average Training Loss 0.041515081265910726
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.0494560453949905
----------------------------EPOCH 16 PATCH 8-------------------------------
Average Training Loss 0.041105461902305726
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04935910062092107
----------------------------EPOCH 17 PATCH 8-------------------------------
Average Training Loss 0.041064652751703734
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.049206476386000474
----------------------------EPOCH 18 PATCH 8-------------------------------
Average Training Loss 0.040802397200318634
Validation Accuracy: 0.7560975609756098
Average Validation Loss: 0.049146141947769534
----------------------------EPOCH 19 PATCH 8------------

Average Training Loss 0.029516752992497116
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.044104350776207155
----------------------------EPOCH 56 PATCH 8-------------------------------
Average Training Loss 0.029334390444345163
Validation Accuracy: 0.7560975609756098
Average Validation Loss: 0.04434009877646842
----------------------------EPOCH 57 PATCH 8-------------------------------
Average Training Loss 0.028559827108363637
Validation Accuracy: 0.7560975609756098
Average Validation Loss: 0.04396911801361456
----------------------------EPOCH 58 PATCH 8-------------------------------
Average Training Loss 0.028439794346445897
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.0438060338904218
----------------------------EPOCH 59 PATCH 8-------------------------------
Average Training Loss 0.028382097355654983
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04362187443709955
----------------------------EPOCH 60 PATCH 8------------

Validation Accuracy: 0.7560975609756098
Average Validation Loss: 0.04203214877989234
----------------------------EPOCH 96 PATCH 8-------------------------------
Average Training Loss 0.017243981483529826
Validation Accuracy: 0.7560975609756098
Average Validation Loss: 0.041957906106623206
----------------------------EPOCH 97 PATCH 8-------------------------------
Average Training Loss 0.016888597674789976
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04226073913457917
----------------------------EPOCH 98 PATCH 8-------------------------------
Average Training Loss 0.01764092863094611
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04187289243791162
----------------------------EPOCH 99 PATCH 8-------------------------------
Average Training Loss 0.0165939132576106
Validation Accuracy: 0.7560975609756098
Average Validation Loss: 0.04191018095830592
----------------------------EPOCH 100 PATCH 8-------------------------------
Average Training Loss 0.

Average Training Loss 0.03361795793791286
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.047594301584290295
----------------------------EPOCH 32 PATCH 9-------------------------------
Average Training Loss 0.032836787952262844
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04752528085941222
----------------------------EPOCH 33 PATCH 9-------------------------------
Average Training Loss 0.03257784680997739
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.047361019181042185
----------------------------EPOCH 34 PATCH 9-------------------------------
Average Training Loss 0.03185334441358926
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.047295861127899914
----------------------------EPOCH 35 PATCH 9-------------------------------
Average Training Loss 0.031945811737267696
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.047189548248197974
----------------------------EPOCH 36 PATCH 9-----------

Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.046787173282809374
BEST VAL ACC: 0.04594071899972311
----------------------------TEST RESULTS-------------------------------
Test Accuracy: 0.7073170731707317
Recall: 0.7142857142857143
Precision: 0.7142857142857143
[tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0.]), tensor([0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 0.]), tensor([1., 1., 0., 0., 1., 1., 0., 1., 0.])]
[tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.]), tensor([0., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0.]), tensor([0., 0., 0., 1., 1., 1., 1., 1., 0.])]
----------------------------EPOCH 0 PATCH 10-------------------------------
Average Training Loss 0.04496838140194533
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.0508809729320247
----------------------------EPOCH 1 PATCH 10-------------------------------
Average Training Loss 0.04462478236585367
Validat

Average Training Loss 0.03502968973556503
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04789095826265288
----------------------------EPOCH 38 PATCH 10-------------------------------
Average Training Loss 0.03407902777439258
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.047793727095534165
----------------------------EPOCH 39 PATCH 10-------------------------------
Average Training Loss 0.03464832529425621
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04770310913644186
----------------------------EPOCH 40 PATCH 10-------------------------------
Average Training Loss 0.033745048964609864
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.047651001592961754
----------------------------EPOCH 41 PATCH 10-------------------------------
Average Training Loss 0.03345485904910525
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.047622828948788526
----------------------------EPOCH 42 PATCH 10--------

Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04696609188870686
----------------------------EPOCH 78 PATCH 10-------------------------------
Average Training Loss 0.02221293640551997
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04700414000487909
----------------------------EPOCH 79 PATCH 10-------------------------------
Average Training Loss 0.020800076058653534
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04693749910447656
----------------------------EPOCH 80 PATCH 10-------------------------------
Average Training Loss 0.02201609032564476
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.046935203598766795
----------------------------EPOCH 81 PATCH 10-------------------------------
Average Training Loss 0.0214903574803325
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.0471347962937704
----------------------------EPOCH 82 PATCH 10-------------------------------
Average Training Loss 

Average Training Loss 0.03960611109362274
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04904960859112623
----------------------------EPOCH 32 PATCH 11-------------------------------
Average Training Loss 0.03999681477663947
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04898800791763678
----------------------------EPOCH 33 PATCH 11-------------------------------
Average Training Loss 0.03986001747553466
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04894254265761957
----------------------------EPOCH 34 PATCH 11-------------------------------
Average Training Loss 0.039214358466570495
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04895602929882887
----------------------------EPOCH 35 PATCH 11-------------------------------
Average Training Loss 0.03912075914320399
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.048841525868671694
----------------------------EPOCH 36 PATCH 11----------

Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04728709343003064
----------------------------EPOCH 72 PATCH 11-------------------------------
Average Training Loss 0.034001018363432806
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04730171110571885
----------------------------EPOCH 73 PATCH 11-------------------------------
Average Training Loss 0.03406530050713508
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04724061343728042
----------------------------EPOCH 74 PATCH 11-------------------------------
Average Training Loss 0.03409095609285792
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04715354268143817
----------------------------EPOCH 75 PATCH 11-------------------------------
Average Training Loss 0.033433760592683416
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04718974014607871
----------------------------EPOCH 76 PATCH 11-------------------------------
Average Training Los

Average Training Loss 0.043614654511701864
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05054993891134495
----------------------------EPOCH 1 PATCH 12-------------------------------
Average Training Loss 0.04371393777307917
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05053429341897732
----------------------------EPOCH 2 PATCH 12-------------------------------
Average Training Loss 0.043735092658488475
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.050507426261901855
----------------------------EPOCH 3 PATCH 12-------------------------------
Average Training Loss 0.04338295037140612
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.050472839576442066
----------------------------EPOCH 4 PATCH 12-------------------------------
Average Training Loss 0.04347561324229006
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.05044197454685118
----------------------------EPOCH 5 PATCH 12-------------

Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.048967070695830554
----------------------------EPOCH 41 PATCH 12-------------------------------
Average Training Loss 0.03641671296514449
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04890460648187777
----------------------------EPOCH 42 PATCH 12-------------------------------
Average Training Loss 0.03659522912052811
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04888009734270049
----------------------------EPOCH 43 PATCH 12-------------------------------
Average Training Loss 0.03626183774627623
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04880665569770627
----------------------------EPOCH 44 PATCH 12-------------------------------
Average Training Loss 0.03578618156616805
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04877777070533938
----------------------------EPOCH 45 PATCH 12-------------------------------
Average Training Loss

Average Training Loss 0.025534652784222463
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04836447936732594
----------------------------EPOCH 82 PATCH 12-------------------------------
Average Training Loss 0.024255638239813634
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04847766713398259
----------------------------EPOCH 83 PATCH 12-------------------------------
Average Training Loss 0.024668570668970952
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.048441349006280665
----------------------------EPOCH 84 PATCH 12-------------------------------
Average Training Loss 0.023870087305053335
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.048562325355483264
----------------------------EPOCH 85 PATCH 12-------------------------------
Average Training Loss 0.023637718414185476
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04851548846175031
----------------------------EPOCH 86 PATCH 12-----

Average Training Loss 0.04148625594670655
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04974962589217395
----------------------------EPOCH 28 PATCH 13-------------------------------
Average Training Loss 0.04129257895907418
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04972762160184907
----------------------------EPOCH 29 PATCH 13-------------------------------
Average Training Loss 0.04132066764792458
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04967658839574674
----------------------------EPOCH 30 PATCH 13-------------------------------
Average Training Loss 0.04101267945571024
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.0496438683533087
----------------------------EPOCH 31 PATCH 13-------------------------------
Average Training Loss 0.041054758258530354
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.0495927813576489
----------------------------EPOCH 32 PATCH 13-------------

Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.048355527040435047
----------------------------EPOCH 68 PATCH 13-------------------------------
Average Training Loss 0.03694621803330594
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04829756515782054
----------------------------EPOCH 69 PATCH 13-------------------------------
Average Training Loss 0.036204674937685984
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.048279891653758726
----------------------------EPOCH 70 PATCH 13-------------------------------
Average Training Loss 0.036199007366524365
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.048239721030723756
----------------------------EPOCH 71 PATCH 13-------------------------------
Average Training Loss 0.035967273546046896
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04821307193942186
----------------------------EPOCH 72 PATCH 13-------------------------------
Average Training

Average Training Loss 0.029742087619226486
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04756237820881169
----------------------------EPOCH 109 PATCH 13-------------------------------
Average Training Loss 0.030626741772303817
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04744670594610819
----------------------------EPOCH 110 PATCH 13-------------------------------
Average Training Loss 0.030086327405249486
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04757022130780104
----------------------------EPOCH 111 PATCH 13-------------------------------
Average Training Loss 0.030736507450947997
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04760630247069568
----------------------------EPOCH 112 PATCH 13-------------------------------
Average Training Loss 0.031220725569568696
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.047570013418430236
----------------------------EPOCH 113 PATCH 13-

Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.05002867157866315
----------------------------EPOCH 20 PATCH 14-------------------------------
Average Training Loss 0.04259080623016983
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05002009141735914
----------------------------EPOCH 21 PATCH 14-------------------------------
Average Training Loss 0.04237951521502167
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05001728563773923
----------------------------EPOCH 22 PATCH 14-------------------------------
Average Training Loss 0.04243141205095854
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04992156203200177
----------------------------EPOCH 23 PATCH 14-------------------------------
Average Training Loss 0.04238344106029292
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.049907742477044825
----------------------------EPOCH 24 PATCH 14-------------------------------
Average Training Loss

Average Training Loss 0.03904454300149542
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04858780052603745
----------------------------EPOCH 61 PATCH 14-------------------------------
Average Training Loss 0.03841432318335674
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04851015166538518
----------------------------EPOCH 62 PATCH 14-------------------------------
Average Training Loss 0.038490448208128816
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04860673590404231
----------------------------EPOCH 63 PATCH 14-------------------------------
Average Training Loss 0.03883635655778353
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.048421744893236855
----------------------------EPOCH 64 PATCH 14-------------------------------
Average Training Loss 0.03832306847220562
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.0485221365603005
----------------------------EPOCH 65 PATCH 14-----------

Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.0473833433011683
----------------------------EPOCH 101 PATCH 14-------------------------------
Average Training Loss 0.035050070981998915
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04740887734948135
----------------------------EPOCH 102 PATCH 14-------------------------------
Average Training Loss 0.03407114786935634
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04736627311241336
----------------------------EPOCH 103 PATCH 14-------------------------------
Average Training Loss 0.034448941955800914
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.047435514810608655
----------------------------EPOCH 104 PATCH 14-------------------------------
Average Training Loss 0.03417510954571552
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.047444086249281724
----------------------------EPOCH 105 PATCH 14-------------------------------
Average Traini

Average Training Loss 0.04225353603480292
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04992807492977235
----------------------------EPOCH 7 PATCH 15-------------------------------
Average Training Loss 0.04211282962169804
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.049824880390632445
----------------------------EPOCH 8 PATCH 15-------------------------------
Average Training Loss 0.041666337823281524
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.049710824722197
----------------------------EPOCH 9 PATCH 15-------------------------------
Average Training Loss 0.04162929488010094
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04961387413304027
----------------------------EPOCH 10 PATCH 15-------------------------------
Average Training Loss 0.0415842849455896
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04949396557924224
----------------------------EPOCH 11 PATCH 15----------------

Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04484739536192359
----------------------------EPOCH 47 PATCH 15-------------------------------
Average Training Loss 0.030668040157341567
Validation Accuracy: 0.7560975609756098
Average Validation Loss: 0.044725096807247254
----------------------------EPOCH 48 PATCH 15-------------------------------
Average Training Loss 0.02937210522225646
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04454999144484357
----------------------------EPOCH 49 PATCH 15-------------------------------
Average Training Loss 0.028915401120654872
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04440850455586503
----------------------------EPOCH 50 PATCH 15-------------------------------
Average Training Loss 0.028800334170704982
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04428336387727319
----------------------------EPOCH 51 PATCH 15-------------------------------
Average Training L

Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.0415774503859078
----------------------------EPOCH 87 PATCH 15-------------------------------
Average Training Loss 0.01751503339189975
Validation Accuracy: 0.7560975609756098
Average Validation Loss: 0.04152932690411079
----------------------------EPOCH 88 PATCH 15-------------------------------
Average Training Loss 0.018279473312565537
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.041504072706873826
----------------------------EPOCH 89 PATCH 15-------------------------------
Average Training Loss 0.017493950783229264
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04154738929213547
----------------------------EPOCH 90 PATCH 15-------------------------------
Average Training Loss 0.017950608013350456
Validation Accuracy: 0.7560975609756098
Average Validation Loss: 0.04161085006667346
----------------------------EPOCH 91 PATCH 15-------------------------------
Average Training Lo

Average Training Loss 0.04064669985263074
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04942320468949109
----------------------------EPOCH 23 PATCH 16-------------------------------
Average Training Loss 0.04035916399271762
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.049376112658803054
----------------------------EPOCH 24 PATCH 16-------------------------------
Average Training Loss 0.04032977511648272
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04929812216177219
----------------------------EPOCH 25 PATCH 16-------------------------------
Average Training Loss 0.04062360464060893
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.0492517962688353
----------------------------EPOCH 26 PATCH 16-------------------------------
Average Training Loss 0.04043541384524987
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04915306771673807
----------------------------EPOCH 27 PATCH 16------------

Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04720119296050653
----------------------------EPOCH 63 PATCH 16-------------------------------
Average Training Loss 0.03400713706114253
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.047183185088925245
----------------------------EPOCH 64 PATCH 16-------------------------------
Average Training Loss 0.033376603402563786
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.047106645456174524
----------------------------EPOCH 65 PATCH 16-------------------------------
Average Training Loss 0.03266201681289516
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.047131110982197084
----------------------------EPOCH 66 PATCH 16-------------------------------
Average Training Loss 0.032593422370855926
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04710623694629204
----------------------------EPOCH 67 PATCH 16-------------------------------
Average Training 

Average Training Loss 0.04469426623621925
Validation Accuracy: 0.3902439024390244
Average Validation Loss: 0.05119479865562625
----------------------------EPOCH 4 PATCH 17-------------------------------
Average Training Loss 0.04472719694747299
Validation Accuracy: 0.3902439024390244
Average Validation Loss: 0.05107672476186985
----------------------------EPOCH 5 PATCH 17-------------------------------
Average Training Loss 0.044704764959264974
Validation Accuracy: 0.3902439024390244
Average Validation Loss: 0.05098798507597388
----------------------------EPOCH 6 PATCH 17-------------------------------
Average Training Loss 0.04423233817835323
Validation Accuracy: 0.4146341463414634
Average Validation Loss: 0.05089941838892495
----------------------------EPOCH 7 PATCH 17-------------------------------
Average Training Loss 0.04430744088575488
Validation Accuracy: 0.4634146341463415
Average Validation Loss: 0.0507935692624348
----------------------------EPOCH 8 PATCH 17-----------------

Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04745092479194083
----------------------------EPOCH 44 PATCH 17-------------------------------
Average Training Loss 0.03784718828611686
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04736513771661898
----------------------------EPOCH 45 PATCH 17-------------------------------
Average Training Loss 0.037783028405220784
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04729953190175498
----------------------------EPOCH 46 PATCH 17-------------------------------
Average Training Loss 0.03803139521938856
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04714680154149125
----------------------------EPOCH 47 PATCH 17-------------------------------
Average Training Loss 0.037681374759947664
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.047106673077839174
----------------------------EPOCH 48 PATCH 17-------------------------------
Average Training Lo

Average Training Loss 0.0304056139143764
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04335916333082246
----------------------------EPOCH 85 PATCH 17-------------------------------
Average Training Loss 0.03021971087475292
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.043288530372991796
----------------------------EPOCH 86 PATCH 17-------------------------------
Average Training Loss 0.029876844010880737
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04313719854122255
----------------------------EPOCH 87 PATCH 17-------------------------------
Average Training Loss 0.029412441810623545
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04308903508070039
----------------------------EPOCH 88 PATCH 17-------------------------------
Average Training Loss 0.030325935145870585
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04307830915218446
----------------------------EPOCH 89 PATCH 17---------

Average Training Loss 0.02315589831378616
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.0425855081255843
----------------------------EPOCH 125 PATCH 17-------------------------------
Average Training Loss 0.023873433715007344
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.042637772676421375
----------------------------EPOCH 126 PATCH 17-------------------------------
Average Training Loss 0.0234891592112721
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04268848314517882
----------------------------EPOCH 127 PATCH 17-------------------------------
Average Training Loss 0.02454547022209793
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04290798815285287
----------------------------EPOCH 128 PATCH 17-------------------------------
Average Training Loss 0.022892918682000676
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04282867181591871
BEST VAL ACC: 0.04229204538391858
------------------

Average Training Loss 0.041905966083534425
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.05020402844359235
----------------------------EPOCH 34 PATCH 18-------------------------------
Average Training Loss 0.04155552289525016
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05019346533752069
----------------------------EPOCH 35 PATCH 18-------------------------------
Average Training Loss 0.041604674497588735
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.05017196405224684
----------------------------EPOCH 36 PATCH 18-------------------------------
Average Training Loss 0.04156830955724247
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05015262452567496
----------------------------EPOCH 37 PATCH 18-------------------------------
Average Training Loss 0.041201375791283905
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.05011405450541798
----------------------------EPOCH 38 PATCH 18---------

Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04933281933389059
----------------------------EPOCH 74 PATCH 18-------------------------------
Average Training Loss 0.03806098098637628
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04932715834640875
----------------------------EPOCH 75 PATCH 18-------------------------------
Average Training Loss 0.037128815763309356
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04929752175400897
----------------------------EPOCH 76 PATCH 18-------------------------------
Average Training Loss 0.03743400730070521
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.049263804424099805
----------------------------EPOCH 77 PATCH 18-------------------------------
Average Training Loss 0.03727126231447595
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.04928847783949317
----------------------------EPOCH 78 PATCH 18-------------------------------
Average Training Los

Average Training Loss 0.031952072667782425
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.04880328876216237
----------------------------EPOCH 115 PATCH 18-------------------------------
Average Training Loss 0.033190459990110555
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04881526638821858
----------------------------EPOCH 116 PATCH 18-------------------------------
Average Training Loss 0.03192545229294261
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.048827161149280825
----------------------------EPOCH 117 PATCH 18-------------------------------
Average Training Loss 0.030487716564389526
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.048780547409522826
----------------------------EPOCH 118 PATCH 18-------------------------------
Average Training Loss 0.031387727707624435
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04881732347534924
----------------------------EPOCH 119 PATCH 18-

Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.049532147442422264
----------------------------EPOCH 18 PATCH 19-------------------------------
Average Training Loss 0.04127149650307953
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04945250400682775
----------------------------EPOCH 19 PATCH 19-------------------------------
Average Training Loss 0.04120583453627884
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04937398433685303
----------------------------EPOCH 20 PATCH 19-------------------------------
Average Training Loss 0.0406810284638014
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04929257892980808
----------------------------EPOCH 21 PATCH 19-------------------------------
Average Training Loss 0.040753956334512745
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04922478228080564
----------------------------EPOCH 22 PATCH 19-------------------------------
Average Training Loss

Average Training Loss 0.03250072359061632
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04629109109320292
----------------------------EPOCH 59 PATCH 19-------------------------------
Average Training Loss 0.03158819730408856
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04626084246286532
----------------------------EPOCH 60 PATCH 19-------------------------------
Average Training Loss 0.031765663843663014
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.046256331408896095
----------------------------EPOCH 61 PATCH 19-------------------------------
Average Training Loss 0.0315361737472112
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04611436768275935
----------------------------EPOCH 62 PATCH 19-------------------------------
Average Training Loss 0.031087415995167903
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.046073358233382065
----------------------------EPOCH 63 PATCH 19---------

Average Training Loss 0.04379683647487984
Validation Accuracy: 0.4878048780487805
Average Validation Loss: 0.05066005049682245
----------------------------EPOCH 6 PATCH 20-------------------------------
Average Training Loss 0.043859534576290944
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.050620666364344154
----------------------------EPOCH 7 PATCH 20-------------------------------
Average Training Loss 0.043992625396759785
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.050582577542560854
----------------------------EPOCH 8 PATCH 20-------------------------------
Average Training Loss 0.04373707634503724
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.05054796905052371
----------------------------EPOCH 9 PATCH 20-------------------------------
Average Training Loss 0.043534548800499714
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.05050201677694553
----------------------------EPOCH 10 PATCH 20-----------

Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04935565227415503
----------------------------EPOCH 46 PATCH 20-------------------------------
Average Training Loss 0.040721430275283875
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.0493259880600906
----------------------------EPOCH 47 PATCH 20-------------------------------
Average Training Loss 0.04072433331462204
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04927724309083892
----------------------------EPOCH 48 PATCH 20-------------------------------
Average Training Loss 0.040961343611850116
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.049254751786953065
----------------------------EPOCH 49 PATCH 20-------------------------------
Average Training Loss 0.04046342888327896
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.0492300536574387
----------------------------EPOCH 50 PATCH 20-------------------------------
Average Training Loss

Average Training Loss 0.03700984342664969
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.047998046002736904
----------------------------EPOCH 87 PATCH 20-------------------------------
Average Training Loss 0.03788079101531232
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.047957747447781446
----------------------------EPOCH 88 PATCH 20-------------------------------
Average Training Loss 0.036827810597224314
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04797362845118453
----------------------------EPOCH 89 PATCH 20-------------------------------
Average Training Loss 0.03749827784104425
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04791103194399578
----------------------------EPOCH 90 PATCH 20-------------------------------
Average Training Loss 0.03674935836528168
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.0478984757167537
----------------------------EPOCH 91 PATCH 20----------

Average Training Loss 0.03332043481899089
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04720377340549376
----------------------------EPOCH 127 PATCH 20-------------------------------
Average Training Loss 0.033043189310148116
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04726048213679616
----------------------------EPOCH 128 PATCH 20-------------------------------
Average Training Loss 0.03368927544501961
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.0474025374505578
----------------------------EPOCH 129 PATCH 20-------------------------------
Average Training Loss 0.03309574497283482
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04730828215436238
----------------------------EPOCH 130 PATCH 20-------------------------------
Average Training Loss 0.03372472414716345
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04718382765607136
----------------------------EPOCH 131 PATCH 20-------

Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05050856892655536
----------------------------EPOCH 9 PATCH 21-------------------------------
Average Training Loss 0.04395634949695869
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.050461738574795606
----------------------------EPOCH 10 PATCH 21-------------------------------
Average Training Loss 0.04358616874354784
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05042502065984214
----------------------------EPOCH 11 PATCH 21-------------------------------
Average Training Loss 0.04351860394731897
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.050381603764324656
----------------------------EPOCH 12 PATCH 21-------------------------------
Average Training Loss 0.04352737547921353
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05033970024527573
----------------------------EPOCH 13 PATCH 21-------------------------------
Average Training Loss

Average Training Loss 0.03912186818044694
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.048724267540908445
----------------------------EPOCH 50 PATCH 21-------------------------------
Average Training Loss 0.03859012683884042
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.0486862862982401
----------------------------EPOCH 51 PATCH 21-------------------------------
Average Training Loss 0.03880724396373405
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04860974957303303
----------------------------EPOCH 52 PATCH 21-------------------------------
Average Training Loss 0.038600472763913575
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.048581040487056824
----------------------------EPOCH 53 PATCH 21-------------------------------
Average Training Loss 0.0387846866592032
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.048528224956698535
----------------------------EPOCH 54 PATCH 21----------

Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.046904997127812084
----------------------------EPOCH 90 PATCH 21-------------------------------
Average Training Loss 0.03190348822562421
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04692922859657102
----------------------------EPOCH 91 PATCH 21-------------------------------
Average Training Loss 0.0320105431754081
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04686808440743423
----------------------------EPOCH 92 PATCH 21-------------------------------
Average Training Loss 0.03209840610134797
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04684800345723222
----------------------------EPOCH 93 PATCH 21-------------------------------
Average Training Loss 0.03158227665746798
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04681895854996472
----------------------------EPOCH 94 PATCH 21-------------------------------
Average Training Loss 

Average Training Loss 0.043793034602384096
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.05063049822318845
----------------------------EPOCH 3 PATCH 22-------------------------------
Average Training Loss 0.043874984271213655
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.050570704588075964
----------------------------EPOCH 4 PATCH 22-------------------------------
Average Training Loss 0.04357319984768258
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.050508829151711815
----------------------------EPOCH 5 PATCH 22-------------------------------
Average Training Loss 0.04346580925534983
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.05042798490059085
----------------------------EPOCH 6 PATCH 22-------------------------------
Average Training Loss 0.0434484246080039
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.050356633779479236
----------------------------EPOCH 7 PATCH 22-------------

Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.047749879883556834
----------------------------EPOCH 43 PATCH 22-------------------------------
Average Training Loss 0.036821928913476035
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04766749027298718
----------------------------EPOCH 44 PATCH 22-------------------------------
Average Training Loss 0.036995350581700684
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.047595291602902295
----------------------------EPOCH 45 PATCH 22-------------------------------
Average Training Loss 0.03653809299967328
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04755400593687848
----------------------------EPOCH 46 PATCH 22-------------------------------
Average Training Loss 0.036934837027162805
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04746380957161508
----------------------------EPOCH 47 PATCH 22-------------------------------
Average Training 

Average Training Loss 0.030146952168863327
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04659181397135665
----------------------------EPOCH 84 PATCH 22-------------------------------
Average Training Loss 0.029851669415098724
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04659502971463087
----------------------------EPOCH 85 PATCH 22-------------------------------
Average Training Loss 0.02945355325937271
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04667664010350297
----------------------------EPOCH 86 PATCH 22-------------------------------
Average Training Loss 0.03088044308003832
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04664874658351991
----------------------------EPOCH 87 PATCH 22-------------------------------
Average Training Loss 0.029412133832935426
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04668531621374735
----------------------------EPOCH 88 PATCH 22---------

Average Training Loss 0.04267251943467093
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04988828228741157
----------------------------EPOCH 26 PATCH 23-------------------------------
Average Training Loss 0.04269921510923104
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04985864278746814
----------------------------EPOCH 27 PATCH 23-------------------------------
Average Training Loss 0.04248145943293806
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04981368925513291
----------------------------EPOCH 28 PATCH 23-------------------------------
Average Training Loss 0.04252246767282486
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04977084514571399
----------------------------EPOCH 29 PATCH 23-------------------------------
Average Training Loss 0.04225895712610151
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.049713136219396825
----------------------------EPOCH 30 PATCH 23-----------

Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04796896765871746
----------------------------EPOCH 66 PATCH 23-------------------------------
Average Training Loss 0.038886917296980246
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.047938061923515504
----------------------------EPOCH 67 PATCH 23-------------------------------
Average Training Loss 0.038516189719809864
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04789000749588013
----------------------------EPOCH 68 PATCH 23-------------------------------
Average Training Loss 0.03900897392972571
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.047839138565993894
----------------------------EPOCH 69 PATCH 23-------------------------------
Average Training Loss 0.03859442582384485
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04779359916361367
----------------------------EPOCH 70 PATCH 23-------------------------------
Average Training L

Average Training Loss 0.03489080272981378
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.046045466167170826
----------------------------EPOCH 107 PATCH 23-------------------------------
Average Training Loss 0.03455530955898957
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.045976534122374
----------------------------EPOCH 108 PATCH 23-------------------------------
Average Training Loss 0.03413914913525347
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.045964246842919325
----------------------------EPOCH 109 PATCH 23-------------------------------
Average Training Loss 0.03423342163689801
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.0459421422423386
----------------------------EPOCH 110 PATCH 23-------------------------------
Average Training Loss 0.03405422375338976
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04588305514033248
----------------------------EPOCH 111 PATCH 23--------

Average Training Loss 0.033184947537594156
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.045342417751870506
----------------------------EPOCH 147 PATCH 23-------------------------------
Average Training Loss 0.031545814981714625
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.045451725401529454
----------------------------EPOCH 148 PATCH 23-------------------------------
Average Training Loss 0.031512066905127194
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04544591322177794
----------------------------EPOCH 149 PATCH 23-------------------------------
Average Training Loss 0.031930312575375444
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04548972118191603
----------------------------EPOCH 150 PATCH 23-------------------------------
Average Training Loss 0.031635653716130335
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04537075612603164
----------------------------EPOCH 151 PATCH 23

Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04934961621354266
----------------------------EPOCH 24 PATCH 24-------------------------------
Average Training Loss 0.04080168328812865
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04927099768708392
----------------------------EPOCH 25 PATCH 24-------------------------------
Average Training Loss 0.04071691204778484
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.049220140387372276
----------------------------EPOCH 26 PATCH 24-------------------------------
Average Training Loss 0.04029159687581609
Validation Accuracy: 0.5121951219512195
Average Validation Loss: 0.049142577299257605
----------------------------EPOCH 27 PATCH 24-------------------------------
Average Training Loss 0.0404477695949742
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04906072413049093
----------------------------EPOCH 28 PATCH 24-------------------------------
Average Training Loss

Average Training Loss 0.032181429936260476
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04609133557575505
----------------------------EPOCH 65 PATCH 24-------------------------------
Average Training Loss 0.03198557908906311
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04599168242477789
----------------------------EPOCH 66 PATCH 24-------------------------------
Average Training Loss 0.03206013179704791
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04592959328395564
----------------------------EPOCH 67 PATCH 24-------------------------------
Average Training Loss 0.03136547758686738
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.045868242659219884
----------------------------EPOCH 68 PATCH 24-------------------------------
Average Training Loss 0.03112385501382781
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.045780764847266966
----------------------------EPOCH 69 PATCH 24---------

Average Training Loss 0.023489537297702225
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.044754609829042016
----------------------------EPOCH 105 PATCH 24-------------------------------
Average Training Loss 0.023805575414759216
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04479861695591996
----------------------------EPOCH 106 PATCH 24-------------------------------
Average Training Loss 0.022732821827540634
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.0446954616686193
----------------------------EPOCH 107 PATCH 24-------------------------------
Average Training Loss 0.02311979228111564
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04471323548293695
----------------------------EPOCH 108 PATCH 24-------------------------------
Average Training Loss 0.021046944054179503
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.044704092711937135
----------------------------EPOCH 109 PATCH 24--

Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04867456017470941
----------------------------EPOCH 29 PATCH 25-------------------------------
Average Training Loss 0.04041965574514671
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04859646791365088
----------------------------EPOCH 30 PATCH 25-------------------------------
Average Training Loss 0.04028200345938323
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04851248351539054
----------------------------EPOCH 31 PATCH 25-------------------------------
Average Training Loss 0.039938983858608806
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.04848003823582719
----------------------------EPOCH 32 PATCH 25-------------------------------
Average Training Loss 0.04025577314075877
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.048395334220514064
----------------------------EPOCH 33 PATCH 25-------------------------------
Average Training Los

Average Training Loss 0.03386692116495039
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04557304992908385
----------------------------EPOCH 70 PATCH 25-------------------------------
Average Training Loss 0.03347712694132914
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04554759438444928
----------------------------EPOCH 71 PATCH 25-------------------------------
Average Training Loss 0.03389395284848135
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04542371121848502
----------------------------EPOCH 72 PATCH 25-------------------------------
Average Training Loss 0.03308946462195428
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04543011653714064
----------------------------EPOCH 73 PATCH 25-------------------------------
Average Training Loss 0.03328469694882143
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.045319230091281054
----------------------------EPOCH 74 PATCH 25-----------

Average Training Loss 0.027136946432903163
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.044333979850862085
----------------------------EPOCH 110 PATCH 25-------------------------------
Average Training Loss 0.027267536911808075
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04439855348773119
----------------------------EPOCH 111 PATCH 25-------------------------------
Average Training Loss 0.028364042644617986
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.044372805734960045
----------------------------EPOCH 112 PATCH 25-------------------------------
Average Training Loss 0.027165598129151296
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.044432022222658484
----------------------------EPOCH 113 PATCH 25-------------------------------
Average Training Loss 0.027289091258263978
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04443822255948695
----------------------------EPOCH 114 PATCH 2

Average Training Loss 0.040938089250541126
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04883083919199502
----------------------------EPOCH 31 PATCH 26-------------------------------
Average Training Loss 0.040690730585426584
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04879365897760159
----------------------------EPOCH 32 PATCH 26-------------------------------
Average Training Loss 0.04045896048917145
Validation Accuracy: 0.5365853658536586
Average Validation Loss: 0.04870862931739993
----------------------------EPOCH 33 PATCH 26-------------------------------
Average Training Loss 0.04035874395096888
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04863425580466666
----------------------------EPOCH 34 PATCH 26-------------------------------
Average Training Loss 0.04043470188731053
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.048579800419691135
----------------------------EPOCH 35 PATCH 26---------

Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.045722468597132984
----------------------------EPOCH 71 PATCH 26-------------------------------
Average Training Loss 0.03471743248280932
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.045617045425787206
----------------------------EPOCH 72 PATCH 26-------------------------------
Average Training Loss 0.034409138694649836
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.045614188764153456
----------------------------EPOCH 73 PATCH 26-------------------------------
Average Training Loss 0.03458737038442346
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.045505687957856714
----------------------------EPOCH 74 PATCH 26-------------------------------
Average Training Loss 0.03404462655059627
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04539187361554402
----------------------------EPOCH 75 PATCH 26-------------------------------
Average Training 

Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04348191255476416
----------------------------EPOCH 111 PATCH 26-------------------------------
Average Training Loss 0.028705897267724646
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04344972750035728
----------------------------EPOCH 112 PATCH 26-------------------------------
Average Training Loss 0.027969761461508078
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04349294959045038
----------------------------EPOCH 113 PATCH 26-------------------------------
Average Training Loss 0.028123473962310883
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04350928562443431
----------------------------EPOCH 114 PATCH 26-------------------------------
Average Training Loss 0.028157322072103374
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04346885623001471
----------------------------EPOCH 115 PATCH 26-------------------------------
Average Train

### Mutlimodal without attention

In [16]:
# Set Seeds
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)


MRI_PET_patch_models = []

# For each patch location, concatenate outputs of trained patch models for each modality and learn local features
for i in range(27):
    MRI_patch_model = MRI_models[i]
    PET_patch_model = PET_models[i]
    MRI_PET_patch_model = multimodal_patch_CNN((MRI_patch_model,PET_patch_model)).to(device)
    
    # Freeze MRI and PET patch model layers (we only want the feature maps)
    for name, param in MRI_patch_model.named_parameters():
            param.requires_grad = False 
            
    for name, param in PET_patch_model.named_parameters():
            param.requires_grad = False 
    
    # Instantiate dataloaders
    MRI_PET_train_dataloaders = patch_train_loaders_MRI_PET[i]
    MRI_PET_val_dataloaders = patch_val_loaders_MRI_PET[i]
    MRI_PET_test_dataloaders = patch_test_loaders_MRI_PET[i]
    
    best_val_loss = np.inf
    patience = 15
    no_improvement = 0
    
    # Binary Cross Entropy Loss
    error = nn.BCELoss()

    # SGD Optimizer
    optimiser = SGD(MRI_PET_patch_model.parameters(), lr=0.0001, momentum=0.99)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)
    
    for epoch in range(500):
        print(f"----------------------------EPOCH {epoch} PATCH {i}-------------------------------")
        MRI_PET_patch_model.train()
        train_model(MRI_PET_patch_model, MRI_PET_train_dataloaders, optimiser, error, scheduler, True)

        MRI_PET_patch_model.eval()
        val_loss = validate_model(MRI_PET_patch_model, MRI_PET_val_dataloaders, error, True)
        
        # Save model weights if improvement seen
        # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(MRI_PET_patch_model.state_dict(), f"./patch_temp/multimodal_PATCH_TEMP{i}")
        else:
            no_improvement += 1
            if no_improvement <= patience:
                continue
            else:
                print(f"BEST VAL ACC: {best_val_loss}")
                break
         
    print("----------------------------TEST RESULTS-------------------------------")
    MRI_PET_patch_model.load_state_dict(torch.load(f"./patch_temp/multimodal_PATCH_TEMP{i}"))
    MRI_PET_patch_model.eval()
    evaluate_model(MRI_PET_patch_model, MRI_PET_test_dataloaders, error, True)
    MRI_PET_patch_models.append(MRI_PET_patch_model)

----------------------------EPOCH 0 PATCH 0-------------------------------
Average Training Loss 0.046653531369615774
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.049145051618901696
----------------------------EPOCH 1 PATCH 0-------------------------------
Average Training Loss 0.04149744390952782
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04714423708799409
----------------------------EPOCH 2 PATCH 0-------------------------------
Average Training Loss 0.03467043330434893
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04541407561883694
----------------------------EPOCH 3 PATCH 0-------------------------------
Average Training Loss 0.026735178332348338
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.044429315299522584
----------------------------EPOCH 4 PATCH 0-------------------------------
Average Training Loss 0.021483917491602115
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.044

Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.05203052046822339
----------------------------EPOCH 17 PATCH 1-------------------------------
Average Training Loss 0.019477558826081088
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.05257413300072274
----------------------------EPOCH 18 PATCH 1-------------------------------
Average Training Loss 0.02123911318476083
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.052279433099234975
----------------------------EPOCH 19 PATCH 1-------------------------------
Average Training Loss 0.018683635248023956
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.054183374090892515
----------------------------EPOCH 20 PATCH 1-------------------------------
Average Training Loss 0.020253878850184502
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.052856056428537135
----------------------------EPOCH 21 PATCH 1-------------------------------
Average Training Loss

Average Training Loss 0.02345179293121471
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04718594507473271
----------------------------EPOCH 7 PATCH 3-------------------------------
Average Training Loss 0.020745083506478638
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04767885949553513
----------------------------EPOCH 8 PATCH 3-------------------------------
Average Training Loss 0.019042784042778562
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04821302469183759
----------------------------EPOCH 9 PATCH 3-------------------------------
Average Training Loss 0.017599522731587536
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04861129202493807
----------------------------EPOCH 10 PATCH 3-------------------------------
Average Training Loss 0.01570639551663008
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.0499592001845197
----------------------------EPOCH 11 PATCH 3------------------

Average Training Loss 0.0238653371935008
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.05352873918486804
BEST VAL ACC: 0.048598687823225815
----------------------------TEST RESULTS-------------------------------
Test Accuracy: 0.8048780487804879
Recall: 0.6666666666666666
Precision: 0.9333333333333333
[tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.]), tensor([0., 1., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.]), tensor([0., 0., 0., 0., 1., 1., 1., 1., 0.])]
[tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.]), tensor([0., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0.]), tensor([0., 0., 0., 1., 1., 1., 1., 1., 0.])]
----------------------------EPOCH 0 PATCH 5-------------------------------
Average Training Loss 0.04137889821021283
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.05049032409016679
----------------------------EPOCH 1 PATCH 5-------------------------------
Average T

Average Training Loss 0.01345612936210437
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04543773575526912
----------------------------EPOCH 13 PATCH 6-------------------------------
Average Training Loss 0.013246084151209378
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04532000640543496
----------------------------EPOCH 14 PATCH 6-------------------------------
Average Training Loss 0.01124993026195491
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04572672378726122
----------------------------EPOCH 15 PATCH 6-------------------------------
Average Training Loss 0.011107133273951342
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04603245200180426
----------------------------EPOCH 16 PATCH 6-------------------------------
Average Training Loss 0.011385197735956459
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.0462440121464613
----------------------------EPOCH 17 PATCH 6---------------

Average Training Loss 0.01645882923888867
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04962265491485596
BEST VAL ACC: 0.04618191282923629
----------------------------TEST RESULTS-------------------------------
Test Accuracy: 0.7073170731707317
Recall: 0.6666666666666666
Precision: 0.7368421052631579
[tensor([1., 0., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1.]), tensor([0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0.]), tensor([0., 0., 0., 0., 1., 0., 1., 1., 0.])]
[tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.]), tensor([0., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0.]), tensor([0., 0., 0., 1., 1., 1., 1., 1., 0.])]
----------------------------EPOCH 0 PATCH 8-------------------------------
Average Training Loss 0.0384050616231121
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.05022357585953503
----------------------------EPOCH 1 PATCH 8-------------------------------
Average Tr

Average Training Loss 0.014658171637747132
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04853664084178645
----------------------------EPOCH 12 PATCH 9-------------------------------
Average Training Loss 0.013671729866354192
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.04957372240903901
----------------------------EPOCH 13 PATCH 9-------------------------------
Average Training Loss 0.013533230932032476
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05001271788666888
----------------------------EPOCH 14 PATCH 9-------------------------------
Average Training Loss 0.01419061703271553
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.05025548033597993
----------------------------EPOCH 15 PATCH 9-------------------------------
Average Training Loss 0.012352839677182377
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05148725974850538
----------------------------EPOCH 16 PATCH 9-------------

Average Training Loss 0.04036203319909143
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04894810478861739
----------------------------EPOCH 2 PATCH 11-------------------------------
Average Training Loss 0.0371227182570051
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.048071717343679286
----------------------------EPOCH 3 PATCH 11-------------------------------
Average Training Loss 0.03615945852437957
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04731174940016211
----------------------------EPOCH 4 PATCH 11-------------------------------
Average Training Loss 0.033332182002849264
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.046767511018892614
----------------------------EPOCH 5 PATCH 11-------------------------------
Average Training Loss 0.03237731125755388
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.046565548675816235
----------------------------EPOCH 6 PATCH 11--------------

Average Training Loss 0.015423520910935323
Validation Accuracy: 0.5609756097560976
Average Validation Loss: 0.05775943035032691
----------------------------EPOCH 18 PATCH 12-------------------------------
Average Training Loss 0.013319885968917707
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.058117282099840115
----------------------------EPOCH 19 PATCH 12-------------------------------
Average Training Loss 0.01316424469906287
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.058351259406020005
----------------------------EPOCH 20 PATCH 12-------------------------------
Average Training Loss 0.013863486314161878
Validation Accuracy: 0.5853658536585366
Average Validation Loss: 0.05995198720838965
BEST VAL ACC: 0.04772568330532167
----------------------------TEST RESULTS-------------------------------
Test Accuracy: 0.6829268292682927
Recall: 0.7619047619047619
Precision: 0.6666666666666666
[tensor([1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.

Average Training Loss 0.034637660337764706
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04702136138590371
----------------------------EPOCH 8 PATCH 14-------------------------------
Average Training Loss 0.03251860022056298
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.047752467597403175
----------------------------EPOCH 9 PATCH 14-------------------------------
Average Training Loss 0.03156055754325429
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04782954803327235
----------------------------EPOCH 10 PATCH 14-------------------------------
Average Training Loss 0.033051568404084346
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.04881416733671979
----------------------------EPOCH 11 PATCH 14-------------------------------
Average Training Loss 0.03046608801747932
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.048367754715244946
----------------------------EPOCH 12 PATCH 14----------

Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04684194340938475
----------------------------EPOCH 21 PATCH 15-------------------------------
Average Training Loss 0.010936751641088822
Validation Accuracy: 0.7804878048780488
Average Validation Loss: 0.04575923884787211
----------------------------EPOCH 22 PATCH 15-------------------------------
Average Training Loss 0.012835534228408923
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.047898296902819375
----------------------------EPOCH 23 PATCH 15-------------------------------
Average Training Loss 0.012989138512582075
Validation Accuracy: 0.7317073170731707
Average Validation Loss: 0.04633507495973169
----------------------------EPOCH 24 PATCH 15-------------------------------
Average Training Loss 0.009256047379897266
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04774867470671491
----------------------------EPOCH 25 PATCH 15-------------------------------
Average Training 

Average Training Loss 0.02758231649144751
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.0424823753717469
----------------------------EPOCH 6 PATCH 17-------------------------------
Average Training Loss 0.02688751479641336
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04304841960348734
----------------------------EPOCH 7 PATCH 17-------------------------------
Average Training Loss 0.025898582653188316
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04325381793626925
----------------------------EPOCH 8 PATCH 17-------------------------------
Average Training Loss 0.02566958206598876
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04345141314878696
----------------------------EPOCH 9 PATCH 17-------------------------------
Average Training Loss 0.023848446696752408
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04363132250018236
----------------------------EPOCH 10 PATCH 17---------------

Average Training Loss 0.05150498669655597
Validation Accuracy: 0.43902439024390244
Average Validation Loss: 0.05149905419931179
----------------------------EPOCH 1 PATCH 19-------------------------------
Average Training Loss 0.04797954368786734
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04956214747777799
----------------------------EPOCH 2 PATCH 19-------------------------------
Average Training Loss 0.04151868673621631
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04699830165723475
----------------------------EPOCH 3 PATCH 19-------------------------------
Average Training Loss 0.03524597373897912
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.04550470375433201
----------------------------EPOCH 4 PATCH 19-------------------------------
Average Training Loss 0.030769491354461578
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04468550187785451
----------------------------EPOCH 5 PATCH 19---------------

Average Training Loss 0.030677397231586644
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.05106640588946459
----------------------------EPOCH 17 PATCH 20-------------------------------
Average Training Loss 0.028573045049045908
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.05165505554617905
----------------------------EPOCH 18 PATCH 20-------------------------------
Average Training Loss 0.02803543573398082
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.052220312560476906
----------------------------EPOCH 19 PATCH 20-------------------------------
Average Training Loss 0.026312649188960185
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.05175894789579438
----------------------------EPOCH 20 PATCH 20-------------------------------
Average Training Loss 0.027612599070932043
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.05137979984283447
----------------------------EPOCH 21 PATCH 20-------

Average Training Loss 0.02857055337946923
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.048156146596117715
----------------------------EPOCH 8 PATCH 22-------------------------------
Average Training Loss 0.027810715810685862
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.04909088844206275
----------------------------EPOCH 9 PATCH 22-------------------------------
Average Training Loss 0.027687622875463766
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.049238710868649366
----------------------------EPOCH 10 PATCH 22-------------------------------
Average Training Loss 0.026781153391863478
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.05048201723796565
----------------------------EPOCH 11 PATCH 22-------------------------------
Average Training Loss 0.0243155053038089
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.050583999331404526
----------------------------EPOCH 12 PATCH 22---------

Average Training Loss 0.03140748597559382
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.048330071495800486
BEST VAL ACC: 0.04649055294874238
----------------------------TEST RESULTS-------------------------------
Test Accuracy: 0.6829268292682927
Recall: 0.5238095238095238
Precision: 0.7857142857142857
[tensor([1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0.]), tensor([0., 0., 0., 1., 1., 0., 1., 1., 1.])]
[tensor([1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.]), tensor([0., 0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0.]), tensor([0., 0., 0., 1., 1., 1., 1., 1., 0.])]
----------------------------EPOCH 0 PATCH 24-------------------------------
Average Training Loss 0.051422397377061065
Validation Accuracy: 0.34146341463414637
Average Validation Loss: 0.051974844641801785
----------------------------EPOCH 1 PATCH 24-------------------------------
Ave

Average Training Loss 0.023993303541277277
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.045705440567760935
----------------------------EPOCH 12 PATCH 25-------------------------------
Average Training Loss 0.024800153571318408
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.046370126852175085
----------------------------EPOCH 13 PATCH 25-------------------------------
Average Training Loss 0.022762024530866107
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.047053343639141175
----------------------------EPOCH 14 PATCH 25-------------------------------
Average Training Loss 0.02378240032274215
Validation Accuracy: 0.7073170731707317
Average Validation Loss: 0.04712112356976765
----------------------------EPOCH 15 PATCH 25-------------------------------
Average Training Loss 0.024421406787682752
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.04762030083958695
----------------------------EPOCH 16 PATCH 25-----

# Stage 3. Patch Fusion

In [17]:
# Stacking all 27 patch based datasets to create a single subject level dataset
# We will now have a subject level dataset where each row is a subject and their 27 patch images
# e.g row 1 would be: [patch_1,....,patch_27] for subject 1

all_patch_train_MRI = torch.stack([dataset.tensors[0] for dataset in patch_train_datasets_MRI_PET], dim=1)
all_patch_val_MRI = torch.stack([dataset.tensors[0] for dataset in patch_val_datasets_MRI_PET], dim=1)
all_patch_test_MRI = torch.stack([dataset.tensors[0] for dataset in patch_test_datasets_MRI_PET], dim=1)

all_patch_train_PET = torch.stack([dataset.tensors[1] for dataset in patch_train_datasets_MRI_PET], dim=1)
all_patch_val_PET = torch.stack([dataset.tensors[1] for dataset in patch_val_datasets_MRI_PET], dim=1)
all_patch_test_PET = torch.stack([dataset.tensors[1] for dataset in patch_test_datasets_MRI_PET], dim=1)

all_patch_train_MRI_PET_dataset  = torch.utils.data.TensorDataset(all_patch_train_MRI, all_patch_train_PET, train_y)
all_patch_val_MRI_PET_dataset = torch.utils.data.TensorDataset(all_patch_val_MRI, all_patch_val_PET, val_y)
all_patch_test_MRI_PET_dataset = torch.utils.data.TensorDataset(all_patch_test_MRI, all_patch_test_PET, test_y)

In [18]:
batch_size = 10
all_patch_train_dataloader_MRI_PET = torch.utils.data.DataLoader(all_patch_train_MRI_PET_dataset, 
                                                                batch_size = batch_size, shuffle = True)
all_patch_val_dataloader_MRI_PET = torch.utils.data.DataLoader(all_patch_val_MRI_PET_dataset, 
                                                                batch_size = batch_size, shuffle = True)
all_patch_test_dataloader_MRI_PET = torch.utils.data.DataLoader(all_patch_test_MRI_PET_dataset, 
                                                                batch_size = batch_size, shuffle = False)



In [43]:
# Define our final multimodal attention-based patch network
class multimodal_CNN(nn.Module):
    def __init__(self, MRI_PET_patch_models):
        super().__init__()
        self.MRI_PET_patch_models = nn.ModuleList(MRI_PET_patch_models)
        self.drop= nn.Dropout(p=0.4)
        self.fc3 = nn.Linear(5400, 1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        patch_outputs = []
        for i in range(27):
            mri_image = x[0][:, i]
            pet_image = x[1][:, i]
            patch_output, _ = self.MRI_PET_patch_models[i]((mri_image, pet_image))
            patch_outputs.append(self.drop(patch_output))
        x = torch.cat(patch_outputs, dim=1)
        x = self.sigmoid(self.fc3(x))
        return x

### Final Multimodal Patch Model with attention

In [36]:
# Reload attention weights for rerunning
MRI_PET_patch_models = []
for i in range(27):
    MRI_PET_patch_model = multimodal_patch_CNN_attention((MRI_models[i],PET_models[i])).to(device)
    MRI_PET_patch_model.load_state_dict(torch.load(f"./patch_temp/multimodal_PATCH_TEMP_attention{i}"))                                           
    MRI_PET_patch_models.append(MRI_PET_patch_model)
    
MRI_patch_model = multimodal_CNN(MRI_PET_patch_models).to(device)

# Freeze layers of the patch models (we only want the feature maps)
for patch_model in MRI_patch_model.MRI_PET_patch_models:
    for name, param in patch_model.named_parameters():
        param.requires_grad = False 

# Set seeds
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)

# Binary Cross Entropy Loss
error = nn.BCELoss()

# SGD Optimizer
optimiser = SGD(MRI_patch_model.parameters(), lr=0.0001, momentum=0.9)
scheduler = lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)


best_val_loss = np.inf
patience = 20
no_improvement = 0

for epoch in range(1000):
    print(f"----------------------------EPOCH {epoch}-------------------------------")
    MRI_patch_model.train()
    train_model(MRI_patch_model, all_patch_train_dataloader_MRI_PET, optimiser, error, scheduler)
    
    MRI_patch_model.eval()
    avg_val_loss = validate_model(MRI_patch_model, all_patch_val_dataloader_MRI_PET, error)
    
    # Save model weights if improvement seen
    # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
    if avg_val_loss <= best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement = 0
        torch.save(MRI_patch_model.state_dict(), "./trained_models_2/MULTIMODAL_MODEL_MODAL_ATTENTION")
    else:
        no_improvement += 1
        if no_improvement <= patience:
            continue
        else:
            print(f"BEST VAL LOSS: {best_val_loss}")
            break
            
print("----------------------------TEST RESULTS-------------------------------")
MRI_patch_model.load_state_dict(torch.load("./trained_models_2/MULTIMODAL_MODEL_MODAL_ATTENTION"))
MRI_patch_model.eval()
evaluate_model(MRI_patch_model, all_patch_test_dataloader_MRI_PET, error)

----------------------------EPOCH 0-------------------------------
Average Training Loss 0.06957613017226828
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.08323001280063536
----------------------------EPOCH 1-------------------------------
Average Training Loss 0.06542665120519575
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.08100275295536692
----------------------------EPOCH 2-------------------------------
Average Training Loss 0.06215867620022571
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.07506222259707568
----------------------------EPOCH 3-------------------------------
Average Training Loss 0.0587429804880111
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.07791784914528452
----------------------------EPOCH 4-------------------------------
Average Training Loss 0.05577169248803717
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.08086491648743792
----------------------------EPO

Average Training Loss 0.02452806300926404
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.06684592438907158
----------------------------EPOCH 43-------------------------------
Average Training Loss 0.02915794249685084
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.0702993557220552
----------------------------EPOCH 44-------------------------------
Average Training Loss 0.026022478938102722
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.09396970199375618
----------------------------EPOCH 45-------------------------------
Average Training Loss 0.02617792951584351
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.06987445165471333
----------------------------EPOCH 46-------------------------------
Average Training Loss 0.024015497370455108
Validation Accuracy: 0.6585365853658537
Average Validation Loss: 0.069707073816439
----------------------------EPOCH 47-------------------------------
Average Training Loss 0.02

### Final Multimodal Patch Model with no attention

In [20]:
# Reload no attention weights for rerunning
MRI_PET_patch_models = []
for i in range(27):
    MRI_PET_patch_model = multimodal_patch_CNN((MRI_models[i],PET_models[i])).to(device)
    MRI_PET_patch_model.load_state_dict(torch.load(f"./patch_temp/multimodal_PATCH_TEMP{i}"))                                           
    MRI_PET_patch_models.append(MRI_PET_patch_model)
    
MRI_patch_model = multimodal_CNN(MRI_PET_patch_models).to(device)

# Freeze layers of the patch models (we only want the feature maps)
for patch_model in MRI_patch_model.MRI_PET_patch_models:
    for name, param in patch_model.named_parameters():
        param.requires_grad = False 

# Set seeds
torch.manual_seed(101)
torch.cuda.manual_seed(101)
random.seed(101)
np.random.seed(101)

# Binary Cross Entropy Loss
error = nn.BCELoss()

# SGD Optimizer
optimiser = SGD(MRI_patch_model.parameters(), lr=0.0001, momentum=0.9)
scheduler = lr_scheduler.LinearLR(optimiser, start_factor=1.0, end_factor=0.1, total_iters=10)


best_val_loss = np.inf
patience = 20
no_improvement = 0

for epoch in range(1000):
    print(f"----------------------------EPOCH {epoch}-------------------------------")
    MRI_patch_model.train()
    train_model(MRI_patch_model, all_patch_train_dataloader_MRI_PET, optimiser, error, scheduler)
    
    MRI_patch_model.eval()
    avg_val_loss = validate_model(MRI_patch_model, all_patch_val_dataloader_MRI_PET, error)
    
    # Save model weights if improvement seen
    # Otherwise stop model training if there is no improvement in loss after "patience" number of runs
    if avg_val_loss <= best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement = 0
        torch.save(MRI_patch_model.state_dict(), "./trained_models_2/MULTIMODAL_MODEL_4")
    else:
        no_improvement += 1
        if no_improvement <= patience:
            continue
        else:
            print(f"BEST VAL LOSS: {best_val_loss}")
            break
            
print("----------------------------TEST RESULTS-------------------------------")
MRI_patch_model.load_state_dict(torch.load("./trained_models_2/MULTIMODAL_MODEL_4"))
MRI_patch_model.eval()
evaluate_model(MRI_patch_model, all_patch_test_dataloader_MRI_PET, error)

----------------------------EPOCH 0-------------------------------
Average Training Loss 0.06529779805511725
Validation Accuracy: 0.6341463414634146
Average Validation Loss: 0.0864112319015875
----------------------------EPOCH 1-------------------------------
Average Training Loss 0.055793249521587714
Validation Accuracy: 0.6097560975609756
Average Validation Loss: 0.07096719014935376
----------------------------EPOCH 2-------------------------------
Average Training Loss 0.045468686301200115
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.06271800834958147
----------------------------EPOCH 3-------------------------------
Average Training Loss 0.04139709152037003
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.06574072198169988
----------------------------EPOCH 4-------------------------------
Average Training Loss 0.034925995852615015
Validation Accuracy: 0.6829268292682927
Average Validation Loss: 0.08918767585986997
----------------------------