In [None]:
#▀▀▀▀▀▀▀▀▀▀▀▀ ████████████████████████████████████████████ ▀▀▀▀▀▀▀▀▀▀▀▀
#▀▀▀▀▀▀▀▀▀▀▀▀ ██ ▄▄ ██ ▄▄▄██ ▀██ ██ ▄▄▄██ ▄▄▀█ ▄▄▀██ █████ ▀▀▀▀▀▀▀▀▀▀▀▀
#▀▀▀▀▀▀▀▀▀▀▀▀ ██ █▀▀██ ▄▄▄██ █ █ ██ ▄▄▄██ ▀▀▄█ ▀▀ ██ █████ ▀▀▀▀▀▀▀▀▀▀▀▀
#▀▀▀▀▀▀▀▀▀▀▀▀ ██ ▀▀▄██ ▀▀▀██ ██▄ ██ ▀▀▀██ ██ █ ██ ██ ▀▀ ██ ▀▀▀▀▀▀▀▀▀▀▀▀
#██████████████████████████████████████████████████████████████████████
#█ ▄▄▄ ██ ▄▄▄██ ▄▄ ██ ▄▀▄ ██ ▄▄▄██ ▀██ █▄▄ ▄▄█ ▄▄▀█▄▄ ▄▄██ ▄▄▄ ██ ▀██ █
#█▄▄▄▀▀██ ▄▄▄██ █▀▀██ █ █ ██ ▄▄▄██ █ █ ███ ███ ▀▀ ███ ████ ███ ██ █ █ █
#█ ▀▀▀ ██ ▀▀▀██ ▀▀▄██ ███ ██ ▀▀▀██ ██▄ ███ ███ ██ ███ ████ ▀▀▀ ██ ██▄ █
#██████████████████████████████████████████████████████████████████████

In [None]:
#WARNING - There is no sanity check on data file naming.
#Input/label pairs are matched by their natural-sorted position, so please ensure there are matched, identically named pairs of inputs/labels present in the DATA folders

#====================================================================
#CONFIGURATION
#====================================================================

#Is training of a model to be performed
trainingModel = True

#Is testing of a model to be performed
testingModel = True

#Should a random selection of the data (Ex. 0.1 or 10%) in TRAIN be moved to TEST (default: 0)
#WARNING: Should only ever do this once, if you do not have a TEST dataset manually made available!
testTrainRatio = 0

#How should data be imported for model input 'GRAY' or 'RGB' (default: 'GRAY')
inputMode = 'GRAY'

#Which GPU devices should be used for training (default: [-1], any/all available; CPU only: [])
gpus = [-1]

#Should training/validation data be entirely stored on GPU (default: True; improves training/validation performance, set to False if OOM occurs)
storeOnDevice = True

#How many filters should be used at the top layer of the network (default: 64); doubled at each deeper level
numStartFilters = 64

#Which optimizer should be used ('AdamW', 'Adam', 'Nadam', 'SGD' or 'RMSProp'; default: 'AdamW')
optimizer = 'AdamW'

#What should the learning rate of the model optimizer be
learningRate = 1e-5

#Beta 1 parameter, used only by 'AdamW', 'Adam', and 'Nadam' (default: 0.5)
beta1 = 0.5

#Beta 2 parameter, used only by 'AdamW', 'Adam', and 'Nadam' (default: 0.999)
beta2 = 0.999

#What fraction of the training data should be used for training, the rest being used for validation (default: 0.8)
#1.0 or using only one input image will use training loss for early stopping criteria
trainValRatio = 0.8

#How many epochs should a model train for at maximum (default: 10000)
numEpochs = 10000

#How many epochs should the model training wait to see an improvement before terminating (default: 100)
maxPatience = 100

#How many epochs at minimum should be performed before starting to save the current best model and consider termination (default: 10)
minimumEpochs = 10

#Side length (pixels) of the square crops produced during augmentation; the crop is (augSize, augSize)
#Suggested: 128 for inputMode='GRAY'; 64 for inputMode='RGB'
#Recommend sticking to powers of 2 for GPU efficiency; shouldn't go above size of input images
augSize = 64

#Should the training data be re-augmented every epoch (default: True)
augTrainData = True

#Should the validation data be augmented (once) at initialization (default: True)
#Validation data needs to be consistent throughout training for early stopping criteria
#Need to enable if input images are not a consistent size and batchsize_VAL=-1
augValData = True

#Training data batch size (default: 1)
#Personal preference to stick with batch sizes of 1, which will train the slowest, but generally gives the best results
#If batch normalization is introduced in the model, this needs to be at least 16!
batchsize_TRN = 1

#Validation data batch size (default: 1); -1 sets as total length of validation set, which can cause an OOM, depending on input dimensionality
#Higher value here will decrease training time
#For a given network, input dimensionality, and number of validation samples start with 1 and double until just below GPU VRAM cap
#Ex. with default settings, dataset, and 24 GB VRAM: augSize=128 -> batchsize_VAL=256; augSize=64 -> batchsize_VAL=-1
#Note validation images must be a consistent size (naturally or through augmentation) for any value other than 1 to function
batchsize_VAL = -1

#RNG seed value to control run-to-run consistency, may slow performance, but should be used during development (-1 to disable)
#WARNING - Cannot guarantee consistency between machines/hardware/software; if running benchmarks or ablation study, be as consistent as possible
manualSeedValue = 0

#Should visualizations of the training progression be generated (default: True)
trainingProgressionVisuals = True

#How often (epochs) should visualizations of the training progression be generated (default: 10)
trainingVizSteps = 10

#Input image extension
fileExt = '.png'

#Name for the trained model
modelName = 'Model'

#Should progress bars use ascii formatting (True in jupyter and False in terminal)
asciiFlag = True


In [None]:
#====================================================================
#EXTERNAL
#====================================================================

#ENVIRONMENTAL VARIABLES
#==================================================================

#Deterministic cuBLAS behavior is requested only when seeding is enabled, and is set here, before torch is imported below
#Guidance quoted from the PyTorch notes (may change in future versions): "Set a debug environment variable CUBLAS_WORKSPACE_CONFIG
#to :16:8 (may limit overall performance) or :4096:8 (will increase library footprint in GPU memory by approximately 24MiB)."
import os
if manualSeedValue != -1:
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    

#IMPORTS
#==================================================================

import copy
import cv2
import datetime
import glob
import math
import matplotlib
import matplotlib.pyplot as plt
import multivolumefile
import natsort
import numpy as np
import pandas as pd
import PIL
import py7zr
import random
import shutil
import sys
import time
import torch
import torch.nn.functional as functional
import torchvision.transforms as transforms
import warnings

from IPython.display import display, HTML
from IPython.core.debugger import set_trace as Tracer
from PIL import Image
from sklearn.metrics import jaccard_score
from torchvision.transforms import v2
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from tqdm.auto import tqdm

#LIBRARY AND WARNINGS SETUP
#==================================================================

#Notebook-only display tweaks: stretch cells to the full browser width and cap result output width at 80%
for cssRule in ("<style>.container { width:100% !important; }</style>",
                "<style>.output_result { max-width:80% !important; }</style>"):
    display(HTML(cssRule))

#When seeding is enabled: request deterministic torch algorithms and seed the torch, numpy, and python RNGs
#(these alone do not affect CUDA-specific operations; CUBLAS_WORKSPACE_CONFIG is set before the imports for that)
if manualSeedValue != -1:
    torch.use_deterministic_algorithms(True)
    for seedRNG in (torch.manual_seed, np.random.seed, random.seed):
        seedRNG(manualSeedValue)
        

In [None]:
#====================================================================
#DIRECTORIES
#====================================================================

dir_TrainingData = './DATA/TRAIN/'
dir_TrainingData_Inputs = dir_TrainingData + 'INPUTS/'
dir_TrainingData_Labels = dir_TrainingData + 'LABELS/'

dir_TestingData = './DATA/TEST/'
dir_TestingData_Inputs = dir_TestingData + 'INPUTS/'
dir_TestingData_Labels = dir_TestingData + 'LABELS/'

dir_Results = './RESULTS/'

#Training outputs; dir_TrainingResults_Model is a folder named after the model that receives the model archive
dir_TrainingResults = dir_Results + 'TRAIN/'
dir_TrainingResults_ModelProgression = dir_TrainingResults + 'Progression/'
dir_TrainingResults_Model = dir_TrainingResults + modelName

dir_TestingResults = dir_Results + 'TEST/'
dir_TestingResults_Summary = dir_TestingResults + 'Summary/'
dir_TestingResults_Predictions = dir_TestingResults + 'Predictions/'

#Clear stale results, but only wipe the TRAIN results (which hold the saved model archive) when a new model is about to be trained;
#a test-only run (trainingModel=False) loads its model from dir_TrainingResults_Model, so that folder must survive
if trainingModel:
    if os.path.exists(dir_Results): shutil.rmtree(dir_Results)
elif testingModel:
    if os.path.exists(dir_TestingResults): shutil.rmtree(dir_TestingResults)

#(Re)create the output tree; exist_ok because directories kept from a previous run may still be present
for directory in [dir_TrainingResults, dir_TrainingResults_ModelProgression, dir_TrainingResults_Model,
                  dir_TestingResults, dir_TestingResults_Summary, dir_TestingResults_Predictions]:
    os.makedirs(directory, exist_ok=True)


In [None]:
#====================================================================
#COMPUTE
#====================================================================

#Resolve the GPU list: no CUDA available -> CPU only ([]); a leading -1 -> every visible CUDA device
#If multiple/parallel GPU acceleration is needed, then adopt DDP using Ray
if not torch.cuda.is_available(): gpus = []
if (len(gpus) > 0) and (gpus[0] == -1): gpus = [*range(torch.cuda.device_count())]

#Count only after the -1 wildcard has been expanded, so numGPUs reflects the devices actually listed
numGPUs = len(gpus)


In [None]:
#====================================================================
#UTILITY CLASSES/METHODS
#==================================================================

#Visualize/save an image without borders/axes
def visualizeBorderless(image, saveLocation, cmap='gray', vmin=None, vmax=None):
    """Colour-map an array and save it directly as an image file (no figure, borders, or axes).

    Parameters:
        image: array of values to visualize; NaNs are ignored when deriving the default limits
        saveLocation: output file path, passed to PIL's Image.save
        cmap: matplotlib colormap, or the name of one (default: 'gray')
        vmin, vmax: display limits; default to the nan-aware min/max of image
    """
    if isinstance(cmap, str): cmap = plt.get_cmap(cmap)
    if vmin is None: vmin = np.nanmin(image)
    if vmax is None: vmax = np.nanmax(image)

    #Normalize to [0, 1]; a constant image (vmax == vmin) would otherwise divide 0 by 0 and turn every pixel into NaN,
    #so map it to the bottom of the colormap instead (NaN pixels stay NaN)
    valueRange = vmax - vmin
    if valueRange == 0:
        imageArray = np.asarray(image, dtype=float)
        normalized = np.where(np.isnan(imageArray), np.nan, 0.0)
    else:
        normalized = (np.clip(image, vmin, vmax) - vmin) / valueRange

    #cmap returns RGBA floats in [0, 1]; scale to 8-bit for saving
    Image.fromarray(np.uint8(cmap(normalized)*255)).save(saveLocation)

#Visualize/save a simple data plot
def basicPlot(xData, yData, saveLocation, xLabel='', yLabel=''):
    """Plot yData against xData as a single black line on a 20x8 inch figure and save it to saveLocation.

    Note: the font size of 18 is set through plt.rc, i.e. globally, so it also applies to figures created afterwards.
    """
    plt.rc('font', size=18)
    figure = plt.figure(figsize=(20,8))
    axis = figure.add_subplot(1,1,1)
    axis.plot(xData, yData, color='black')
    axis.set(xlabel=xLabel, ylabel=yLabel)
    figure.savefig(saveLocation)
    plt.close(figure)


In [None]:
#====================================================================
#NETWORK CLASSES/METHODS
#====================================================================

#Slightly modified, but still quite basic U-Net architecture
#Uses leaky ReLU activations throughout to better propagate loss during backpropagation
#Upsampling swaps in nearest-neighbor resizing in place of convolutional transposition to remove checkerboard artifacts
#Augmentation is re-applied to the training data every epoch (on each sample fetch) to maximize data variance; consider adding additional transforms
#Bias disabled for efficiency; shouldn't make a noticeable difference here

#Downsampling convolutional block
class Conv_Dn(nn.Module):
    """Encoder block: two 3x3 'same'-padded, bias-free convolutions (numIn -> numOut -> numOut), each followed by LeakyReLU(0.2).

    Spatial size is preserved; the max-pooling between levels is applied by the caller (Model.forward).
    """
    def __init__(self, numIn, numOut):
        super().__init__()
        convSettings = dict(kernel_size=3, stride=1, padding='same', bias=False)
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.conv0 = nn.Conv2d(in_channels=numIn, out_channels=numOut, **convSettings)
        self.conv1 = nn.Conv2d(in_channels=numOut, out_channels=numOut, **convSettings)
        #Replace the default convolution initialization with a small-std normal distribution
        for conv in (self.conv0, self.conv1):
            nn.init.normal_(conv.weight, mean=0.0, std=0.02)

    def forward(self, data):
        hidden = self.act(self.conv0(data))
        hidden = self.act(self.conv1(hidden))
        return hidden

#Upsampling convolutional block
class Conv_Up(nn.Module):
    """Decoder block: nearest-neighbor resize of `data` to the spatial size of `skip`, a 3x3 convolution (numIn -> numIn),
    concatenation with `skip` along the channel axis, then a 3x3 convolution down to numOut channels.
    Both convolutions are bias-free, 'same'-padded, and followed by LeakyReLU(0.2).

    `skip` is expected to carry numOut channels, since conv1 takes numIn+numOut input channels.
    """
    def __init__(self, numIn, numOut):
        super().__init__()
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.conv0 = nn.Conv2d(in_channels=numIn, out_channels=numIn, kernel_size=3, stride=1, padding='same', bias=False)
        self.conv1 = nn.Conv2d(in_channels=numIn+numOut, out_channels=numOut, kernel_size=3, stride=1, padding='same', bias=False)
        for weight in (self.conv0.weight, self.conv1.weight):
            nn.init.normal_(weight, mean=0.0, std=0.02)

    def forward(self, data, skip):
        #Resize (instead of transposed convolution) to avoid checkerboard artifacts
        upsampled = functional.interpolate(data, size=skip.shape[2:], mode='nearest')
        upsampled = self.act(self.conv0(upsampled))
        merged = torch.cat((upsampled, skip), dim=1)
        return self.act(self.conv1(merged))
    
class Model(nn.Module):
    """Basic U-Net: five encoder levels and four decoder levels joined by skip connections.

    Parameters:
        numFilt: filters at the full-resolution level; doubled at each of the four deeper levels (up to numFilt*16)
        numChan: number of input channels (1 for inputMode 'GRAY', 3 for 'RGB')

    forward() returns a single-channel map at the input's spatial size, with no output activation applied.
    """
    def __init__(self, numFilt, numChan):
        super().__init__()

        #Encoder; 2x max-pooling is applied between consecutive levels
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.convDn0 = Conv_Dn(numChan, numFilt)
        self.convDn1 = Conv_Dn(numFilt, numFilt*2)
        self.convDn2 = Conv_Dn(numFilt*2, numFilt*4)
        self.convDn3 = Conv_Dn(numFilt*4, numFilt*8)
        self.convDn4 = Conv_Dn(numFilt*8, numFilt*16)

        #Decoder; each block upsamples to the size of its skip connection
        self.convUp3 = Conv_Up(numFilt*16, numFilt*8)
        self.convUp2 = Conv_Up(numFilt*8, numFilt*4)
        self.convUp1 = Conv_Up(numFilt*4, numFilt*2)
        self.convUp0 = Conv_Up(numFilt*2, numFilt)

        self.convOut = nn.Conv2d(in_channels=numFilt, out_channels=1, kernel_size=3, stride=1, padding='same', bias=False)
        nn.init.normal_(self.convOut.weight, mean=0.0, std=0.02)

    def forward(self, data):
        #Encoder path
        convDn0 = self.convDn0(data)
        convDn1 = self.convDn1(self.pool(convDn0))
        convDn2 = self.convDn2(self.pool(convDn1))
        convDn3 = self.convDn3(self.pool(convDn2))
        convDn4 = self.convDn4(self.pool(convDn3))

        #Decoder path: every level consumes the output of the decoder level below it, while the same-depth encoder
        #features arrive only through the skip input; this chaining is what lets the bottleneck (convDn4) reach the output
        convUp3 = self.convUp3(convDn4, convDn3)
        convUp2 = self.convUp2(convUp3, convDn2)
        convUp1 = self.convUp1(convUp2, convDn1)
        convUp0 = self.convUp0(convUp1, convDn0)
        return self.convOut(convUp0)
    
#Perform augmentation and setup for DLADS data processing
class DataPreprocessing(Dataset):
    """Dataset of (input, label) tensor pairs with optional joint random augmentation.

    Each input is concatenated with its label along dim 0 before transforming, so a single random crop/flip applies
    to both; the result is split back apart at the input's channel count.

    Parameters:
        inputs, labels: sequences of numpy arrays with the channel axis first (dim 0)
        device: torch device the tensors are moved to when the global storeOnDevice is True
        augmentFlag: whether to build the random augmentation pipeline (crop to augSize x augSize, flips)
        trainDataFlag: True -> augment afresh on every fetch; False -> augment once here so validation stays fixed
    """
    def __init__(self, inputs, labels, device, augmentFlag, trainDataFlag):
        super().__init__()
        self.noAugmentFlag = not augmentFlag
        self.trainDataFlag = trainDataFlag

        #Convert to float tensors, optionally keeping them resident on the compute device
        def asTensor(array):
            tensor = torch.from_numpy(array).float()
            return tensor.to(device) if storeOnDevice else tensor
        self.data_Inputs = [asTensor(item) for item in inputs]
        self.data_Labels = [asTensor(item) for item in labels]

        #Number of leading (input) channels in a concatenated input+label tensor
        self.channelSplit = [self.data_Inputs[0].shape[0]]

        if self.noAugmentFlag: return
        self.transform = transforms.Compose([
            v2.RandomResizedCrop(size=(augSize, augSize), scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0), antialias=True), #https://arxiv.org/pdf/1409.4842.pdf
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5)
        ])

        #Validation data is augmented exactly once, here, for stable execution of early stopping criteria
        if self.trainDataFlag: return
        for index in range(len(self.data_Inputs)):
            self.data_Inputs[index], self.data_Labels[index] = self._augmentPair(index)

    def _augmentPair(self, index):
        """Jointly transform the stored pair at `index`; returns an (input, label) tuple of tensors."""
        stacked = torch.cat([self.data_Inputs[index], self.data_Labels[index]], 0)
        return torch.tensor_split(self.transform(stacked), self.channelSplit, 0)

    def __getitem__(self, index):
        #Only training data with augmentation enabled is re-augmented on each fetch
        if self.trainDataFlag and not self.noAugmentFlag: return self._augmentPair(index)
        return self.data_Inputs[index], self.data_Labels[index]

    def __len__(self):
        return len(self.data_Inputs)

#Define GeneralSegmentation network
class GeneralSegmentation:
    """Build the U-Net and either train it from DATA/TRAIN or load the previously saved parameters.

    Parameters:
        trainFlag: True -> create the optimizer, load the TRAIN data, and train immediately (the best parameters are
                   archived under dir_TrainingResults_Model); False -> extract and load those archived parameters and
                   switch the model to evaluation mode for inferencing
        local_gpus: list of CUDA device indices; only the first entry is used, an empty list selects the CPU
    """
    def __init__(self, trainFlag, local_gpus):

        #Create model; exit immediately on an unsupported mode rather than failing later on a missing self.model
        if inputMode == 'GRAY': self.model = Model(numStartFilters, 1)
        elif inputMode == 'RGB': self.model = Model(numStartFilters, 3)
        else: sys.exit("Error - inputMode must be either 'GRAY' or 'RGB'")

        #If not training, load parameters (before potential parallelization on multiple GPUs) and setup for inferencing
        if not trainFlag:
            with multivolumefile.open(dir_TrainingResults_Model + os.path.sep + modelName + '.7z', mode='rb') as modelArchive:
                with py7zr.SevenZipFile(modelArchive, 'r') as archive:
                    archive.extract(dir_TrainingResults)
            #The parameters were saved from whichever device training used; map them to the CPU so a GPU-trained model
            #also loads on a CPU-only machine (the model is moved to the target device below)
            _ = self.model.load_state_dict(torch.load(dir_TrainingResults_Model + '.pt', map_location='cpu'))
            _ = self.model.train(False)
            os.remove(dir_TrainingResults_Model + '.pt')

        #Configure CPU/GPU computation environment
        self.device = torch.device(f"cuda:{local_gpus[0]}" if len(local_gpus) > 0 else "cpu")
        self.model.to(self.device)

        #If training, setup optimizers, load the data, and perform training
        if trainFlag:
            if optimizer == 'AdamW': self.opt = optim.AdamW(self.model.parameters(), lr=learningRate, betas=(beta1, beta2))
            elif optimizer == 'Adam': self.opt = optim.Adam(self.model.parameters(), lr=learningRate, betas=(beta1, beta2))
            elif optimizer == 'Nadam': self.opt = optim.NAdam(self.model.parameters(), lr=learningRate, betas=(beta1, beta2))
            elif optimizer == 'SGD': self.opt = optim.SGD(self.model.parameters(), lr=learningRate)
            elif optimizer == 'RMSProp': self.opt = optim.RMSprop(self.model.parameters(), lr=learningRate)
            else: sys.exit("Error - optimizer must be one of 'AdamW', 'Adam', 'Nadam', 'SGD', or 'RMSProp'")
            self.loadData()
            self.train()

    def loadData(self):
        """Load the TRAIN input/label pairs, split them into training/validation sets, and build the DataLoaders.

        Files [0, trainValSplit) in natural-sort order are used for training and [trainValSplit, numInputs) for
        validation. Up to two validation samples are also kept (un-augmented) for the training-progression figures.
        """

        #Accumulate input files; sorting by name to ensure consistent order behavior
        filenames_Inputs = natsort.natsorted(glob.glob(dir_TrainingData_Inputs+'*'+fileExt), reverse=False)
        filenames_Labels = natsort.natsorted(glob.glob(dir_TrainingData_Labels+'*'+fileExt), reverse=False)
        numInputs = len(filenames_Inputs)
        if numInputs != len(filenames_Labels): sys.exit('Error - The number of inputData and label files for training do not match')
        if numInputs == 0: sys.exit('Error - No training files matching *' + fileExt + ' were found in ' + dir_TrainingData_Inputs)

        #Find index to split data into training/validation sets; clamped so the training set is never empty
        #A single input (or trainValRatio=1.0) leaves no validation set, so training loss drives early stopping
        trainValSplit = min(max(int(trainValRatio*numInputs), 1), numInputs)

        #If there is not going to be a validation set then indicate such, otherwise setup needed variables
        inputs_VAL, labels_VAL = [], []
        self.valFlag = trainValSplit < numInputs
        if self.valFlag:
            #Illustrated samples must come from the validation range, and the figure has room for two;
            #fall back to the last file when the default picks fall outside that range
            vizSampleIndices = [index for index in (numInputs-14, numInputs-4) if index >= trainValSplit]
            if len(vizSampleIndices) == 0: vizSampleIndices = [numInputs-1]
            self.numViz = len(vizSampleIndices)
            self.inputs_VIZ_NET, self.inputs_VIZ, self.labels_VIZ = [], [], []

        #Extract and prepare data
        inputs_TRN, labels_TRN = [], []
        for index in tqdm(range(0, numInputs), desc = 'Loading TRN Data', leave=True, ascii=asciiFlag):

            #Load inputs and labels (scaled by 1/255), arranging input dimensions to be [C, H, W]
            if inputMode=='GRAY': inputData = np.expand_dims(cv2.cvtColor(cv2.imread(filenames_Inputs[index]), cv2.COLOR_BGR2GRAY).astype(np.float32)/255.0, 0)
            elif inputMode=='RGB': inputData = np.moveaxis(cv2.cvtColor(cv2.imread(filenames_Inputs[index]), cv2.COLOR_BGR2RGB).astype(np.float32)/255.0, -1, 0)
            label = cv2.cvtColor(cv2.imread(filenames_Labels[index]), cv2.COLOR_BGR2GRAY).astype(np.float32)/255.0

            #Handle any validation data intended to be shown in model progression images
            if self.valFlag and (index in vizSampleIndices):

                #Setup input data for processing by the network, arranging dimensions to be [B=1, C, H, W]
                if storeOnDevice: self.inputs_VIZ_NET.append(torch.from_numpy(np.expand_dims(inputData, 0)).float().to(self.device))
                else: self.inputs_VIZ_NET.append(torch.from_numpy(np.expand_dims(inputData, 0)).float())

                #Setup input data for visualization, arranging dimension to be [H, W, C (if applicable)]
                if inputMode=='GRAY': self.inputs_VIZ.append(inputData[0])
                elif inputMode=='RGB': self.inputs_VIZ.append(np.moveaxis(inputData, 0, -1))
                self.labels_VIZ.append(label)

            #Store data into the appropriate dataset; training or validation (strict '<' keeps exactly trainValSplit training files)
            label = np.expand_dims(label, 0)
            if index < trainValSplit:
                inputs_TRN.append(inputData)
                labels_TRN.append(label)
            else:
                inputs_VAL.append(inputData)
                labels_VAL.append(label)

        #Setup data handler for training dataset
        data_TRN = DataPreprocessing(inputs_TRN, labels_TRN, self.device, augTrainData, True)
        self.dataloader_TRN = DataLoader(data_TRN, batch_size=batchsize_TRN, num_workers=0, shuffle=True)
        self.numTRN = len(self.dataloader_TRN)

        #Setup data handler for validation dataset (guaranteed non-empty when valFlag is set)
        if self.valFlag:
            data_VAL = DataPreprocessing(inputs_VAL, labels_VAL, self.device, augValData, False)
            if batchsize_VAL == -1: self.dataloader_VAL = DataLoader(data_VAL, batch_size=len(inputs_VAL), num_workers=0, shuffle=False)
            else: self.dataloader_VAL = DataLoader(data_VAL, batch_size=batchsize_VAL, num_workers=0, shuffle=False)
            self.numVAL = len(self.dataloader_VAL)

    #Produce/save visualizations after a training epoch
    def visualizeTraining(self, epoch):
        """Save a training-progression figure for `epoch` to dir_TrainingResults_ModelProgression.

        The figure always shows the loss curve(s); with a validation set it also shows, for each tracked validation
        sample, the input, ground truth, raw prediction, and thresholded prediction with its Jaccard score.
        """

        #Setup blank canvas; two extra rows are needed for the validation sample illustrations
        if self.valFlag: f = plt.figure(figsize=(24,15))
        else: f = plt.figure(figsize=(24,5))
        f.subplots_adjust(top = 0.90)
        f.subplots_adjust(wspace=0.2, hspace=0.2)

        #Visualize training/validation loss plots
        if self.valFlag: ax = plt.subplot2grid((3,1), (0,0))
        else: ax = plt.subplot2grid((1,1), (0,0))
        ax.plot(self.loss_Trn, label='Training')
        if self.valFlag: ax.plot(self.loss_Val, label='Validation')
        ax.legend(loc='upper right', fontsize=14)
        ax.set_yscale('log')

        #Process and visualize each of the validation samples intended for illustration across training progression
        if self.valFlag:
            for vizSampleNum in range(0, self.numViz):

                if not storeOnDevice: inputData = self.inputs_VIZ_NET[vizSampleNum].to(self.device)
                else: inputData = self.inputs_VIZ_NET[vizSampleNum]
                input_VIZ = self.inputs_VIZ[vizSampleNum]
                label_REAL = self.labels_VIZ[vizSampleNum]

                label_PRED, label_PRED_Processed = self.inference(inputData, False)
                score_Jaccard = jaccard_score(label_REAL.astype(int), label_PRED_Processed, pos_label=1, average='micro', zero_division=1.0)

                ax = plt.subplot2grid((3,4), (vizSampleNum+1,0))
                im = ax.imshow(input_VIZ, aspect='auto', cmap='gray', vmin=0, vmax=1, interpolation='none')
                ax.set_title('Input', fontsize=15, fontweight='bold')
                cbar = f.colorbar(im, ax=ax, orientation='vertical', pad=0.01)
                cbar.formatter.set_powerlimits((0, 0))

                ax = plt.subplot2grid((3,4), (vizSampleNum+1,1))
                im = ax.imshow(label_REAL, aspect='auto', cmap='gray', vmin=0, vmax=1, interpolation='none')
                ax.set_title('Ground-Truth', fontsize=15, fontweight='bold')
                cbar = f.colorbar(im, ax=ax, orientation='vertical', pad=0.01)
                cbar.formatter.set_powerlimits((0, 0))

                ax = plt.subplot2grid((3,4), (vizSampleNum+1,2))
                im = ax.imshow(label_PRED, aspect='auto', cmap='gray', interpolation='none')
                plotTitle = 'PRED'
                ax.set_title(plotTitle, fontsize=15, fontweight='bold')
                cbar = f.colorbar(im, ax=ax, orientation='vertical', pad=0.01)
                cbar.formatter.set_powerlimits((0, 0))

                ax = plt.subplot2grid((3,4), (vizSampleNum+1,3))
                im = ax.imshow(label_PRED_Processed, aspect='auto', cmap='gray', vmin=0, vmax=1, interpolation='none')
                plotTitle = 'PRED>=0.5 - Jaccard: ' + '{:.6f}'.format(round(score_Jaccard, 6))
                ax.set_title(plotTitle, fontsize=15, fontweight='bold')
                cbar = f.colorbar(im, ax=ax, orientation='vertical', pad=0.01)
                cbar.formatter.set_powerlimits((0, 0))

        #Summarize the epoch in the title; a VAL loss only exists when there is a validation set
        plotTitle = 'Epoch: '+str(epoch)+'     Patience: '+str(self.patience)+'/'+str(maxPatience)
        plotTitle += '\nBest Loss: '+ '{:.6f}'.format(round(self.bestLoss, 6)) +' at Epoch: '+str(self.bestEpoch)
        plotTitle += '\nLoss - TRN: ' + '{:.6f}'.format(round(self.loss_Trn[-1], 6))
        if self.valFlag: plotTitle += '     VAL: ' + '{:.6f}'.format(round(self.loss_Val[-1], 6))
        plt.suptitle(plotTitle, fontsize=20, fontweight='bold')

        #Save and always close the figure; figures left open accumulate in memory across epochs
        f.savefig(dir_TrainingResults_ModelProgression + 'epoch_' +str(epoch) + '.tiff', bbox_inches='tight')
        plt.close(f)

    def computeLoss(self, data, label, trainFlag=False):
        """Return the mean absolute error of the model on one batch; when trainFlag, also backpropagate and step the optimizer."""

        #Zero network gradients
        if trainFlag: self.model.zero_grad()

        #Compute MAE
        if not storeOnDevice: data, label = data.to(self.device), label.to(self.device)
        loss = torch.mean(torch.abs(self.model(data)-label))

        #Compute loss gradients and update network parameters
        if trainFlag:
            loss.backward()
            self.opt.step()

        return loss.item()

    def train(self):
        """Train with early stopping, then archive the best parameters and write the timing/history files.

        The monitored loss is the validation loss when a validation set exists, otherwise the training loss.
        Checkpoints are only considered from epoch minimumEpochs onward; training stops after maxPatience epochs
        without improvement, or after numEpochs epochs.
        """

        #Setup storage for losses
        self.loss_Trn, self.loss_Val = [], []

        #Setup variables for early stopping criteria
        bestModel, self.bestLoss, self.bestEpoch, self.patience, endTraining = None, np.inf, -1, 0, False

        #Create progress bar
        trainingBar = tqdm(range(numEpochs), desc="Epochs", leave=True, ascii=asciiFlag)

        #Perform model training
        t0 = time.time()
        for epoch in trainingBar:

            #Compute losses over the training dataset
            _ = self.model.train(True)
            self.loss_Trn.append(np.mean([self.computeLoss(data, label, True) for data, label in tqdm(self.dataloader_TRN, total=self.numTRN, desc='TRN Batches', leave=False, ascii=asciiFlag)]))

            #Compute losses over the validation dataset
            _ = self.model.train(False)
            if self.valFlag:
                with torch.inference_mode():
                    self.loss_Val.append(np.mean([self.computeLoss(data, label, False) for data, label in tqdm(self.dataloader_VAL, total=self.numVAL, desc='VAL Batches', leave=False, ascii=asciiFlag)]))

            #If applicable: update best model parameters or increase patience
            if (epoch >= minimumEpochs):

                if self.valFlag: currLoss = self.loss_Val[-1]
                else: currLoss = self.loss_Trn[-1]

                if (currLoss <= self.bestLoss): bestModel, self.bestLoss, self.bestEpoch, self.patience = copy.deepcopy(self.model.state_dict()), currLoss, epoch, 0
                else: self.patience += 1

            #Update progress bar with epoch data
            progBarString = "PAT: " + str(self.patience) + "/" + str(maxPatience)
            progBarString += ", LOSS -"
            progBarString += " TRN: " + '{:.6f}'.format(round(self.loss_Trn[-1], 6))
            if self.valFlag: progBarString += ", VAL: " + '{:.6f}'.format(round(self.loss_Val[-1], 6))
            trainingBar.set_postfix_str(progBarString)
            trainingBar.refresh()

            #Exit training if early stopping criteria is triggered
            if self.patience >= maxPatience: endTraining = True

            #Perform visualization(s) if applicable
            if trainingProgressionVisuals and ((epoch == 0) or (epoch % trainingVizSteps == 0) or endTraining or (self.bestEpoch == epoch)): self.visualizeTraining(epoch)

            #If training should be terminated, exit the loop
            if endTraining: break

        t1 = time.time()
        trainingTime = datetime.timedelta(seconds=(t1-t0))

        lines = ['Model Training Time: ' + str(trainingTime)]
        with open(dir_TrainingResults + 'trainingTime.txt', 'w') as f:
            for line in lines: _ = f.write(line+'\n')
        print(lines[0])

        #No checkpoint is recorded before minimumEpochs; if training ended before then, fall back to the final parameters
        if bestModel is None: bestModel = copy.deepcopy(self.model.state_dict())

        #Strip out any parallel 'module' references from the model definition
        bestModel = {key.replace("module.", ""): value for key, value in bestModel.items()}

        #Store the model across multiple 100 Mb files to bypass Github file size limits
        torch.save(bestModel, dir_TrainingResults_Model + '.pt')
        if os.path.exists(dir_TrainingResults_Model): shutil.rmtree(dir_TrainingResults_Model)
        os.makedirs(dir_TrainingResults_Model)
        with multivolumefile.open(dir_TrainingResults_Model + os.path.sep + modelName + '.7z', mode='wb', volume=104857600) as modelArchive:
            with py7zr.SevenZipFile(modelArchive, 'w') as archive:
                archive.writeall(dir_TrainingResults_Model + '.pt', modelName + '.pt')
        os.remove(dir_TrainingResults_Model + '.pt')

        #Save training history (one row per completed epoch); the VAL column is only present when validation was performed
        history = np.vstack([np.arange(len(self.loss_Trn)), self.loss_Trn])
        columns = ['Epoch', 'Loss_TRN']
        if self.valFlag:
            history = np.vstack([history, self.loss_Val])
            columns.append('Loss_VAL')
        pd.DataFrame(history.T, columns=columns).to_csv(dir_TrainingResults+'trainingHistory.csv', index=False)

    def inference(self, inputData, transferFlag):
        """Run the model on a batched input tensor.

        Returns the first sample's first output channel as a numpy array, and its binarization (>= 0.5) as ints.
        When transferFlag is True the input is first moved to self.device.
        """
        if transferFlag: inputData = inputData.to(self.device)
        with torch.inference_mode(): label_PRED = self.model(inputData).detach().cpu().numpy()[0, 0]
        label_PRED_Processed = (label_PRED>=0.5).astype(int)
        return label_PRED, label_PRED_Processed

In [None]:
#====================================================================
#MAIN PROGRAM
#====================================================================

#If a TEST set has to be split out from the inputs stored in TRAIN, do so before importing and training a model
if testTrainRatio > 0:

    #Verify the TEST directories are empty before proceeding
    if len(glob.glob(dir_TestingData_Inputs+'*'+fileExt)) > 0 or len(glob.glob(dir_TestingData_Labels+'*'+fileExt)) > 0:
        print("\nWarning - TEST directory already contains files. Either disable testTrainRatio, which should only ever be run once, or clear the TEST/INPUTS and TEST/LABELS directories. Proceeding under the assumption that testing should use only the files that are already present in the DATA/TEST/ directories.\n")
    else:
        #Reset deterministic behavior for torch, numpy, and python so the random selection is reproducible
        #(these alone do not affect CUDA-specific operations)
        if manualSeedValue != -1:
            torch.use_deterministic_algorithms(True)
            torch.manual_seed(manualSeedValue)
            np.random.seed(manualSeedValue)
            random.seed(manualSeedValue)

        #Accumulate input files; sorting by name to ensure consistent order behavior (pairs are matched by position)
        filenames_Inputs = natsort.natsorted(glob.glob(dir_TrainingData_Inputs+'*'+fileExt), reverse=False)
        filenames_Labels = natsort.natsorted(glob.glob(dir_TrainingData_Labels+'*'+fileExt), reverse=False)
        numInputs = len(filenames_Inputs)
        if numInputs != len(filenames_Labels): sys.exit('Error - The number of inputData and label files in TRAIN do not match; cannot split out a TEST set')

        #Number of input/label pairs to move from TRAIN to TEST (never more than are available)
        testTrainSplit = min(int(testTrainRatio*numInputs), numInputs)

        if testTrainSplit <= 0:
            #Guard: a [-0:] slice would select every file and empty TRAIN entirely
            print("\nWarning - testTrainRatio selects no files to move to TEST with " + str(numInputs) + " input(s); no files were moved.\n")
        else:
            #Move a random selection of pairs to the TEST directory, keeping each input together with its label
            filenames = list(zip(filenames_Inputs, filenames_Labels))
            random.shuffle(filenames)
            os.makedirs(dir_TestingData_Inputs, exist_ok=True)
            os.makedirs(dir_TestingData_Labels, exist_ok=True)
            for filename_Input, filename_Label in filenames[len(filenames)-testTrainSplit:]:
                shutil.move(filename_Input, dir_TestingData_Inputs+os.path.basename(filename_Input))
                shutil.move(filename_Label, dir_TestingData_Labels+os.path.basename(filename_Label))

#Train a new model
if trainingModel:

    #Re-seed (when enabled) so model initialization draws from freshly seeded RNGs
    #(these alone do not affect CUDA-specific operations)
    if manualSeedValue != -1:
        torch.use_deterministic_algorithms(True)
        for seedRNG in (torch.manual_seed, np.random.seed, random.seed):
            seedRNG(manualSeedValue)

    #Constructing with trainFlag=True also loads the TRAIN data and runs training
    network = GeneralSegmentation(True, gpus)

#Test a trained model
#Loads the existing model, reads every input/label pair from the TEST directories, scores each prediction with the Jaccard index,
#saves a per-sample summary figure and a borderless prediction image, and writes the mean +/- std score to dataPrintout.csv
if testingModel:

    #Reset deterministic behavior for torch, numpy, and python (these alone do not affect CUDA-specific operations)
    if manualSeedValue != -1:
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(manualSeedValue)
        np.random.seed(manualSeedValue)
        random.seed(manualSeedValue)

    #Only 'GRAY' and 'RGB' are handled below; fail early instead of reaching an undefined inputData in the loading loop
    if inputMode not in ('GRAY', 'RGB'): sys.exit("Error - inputMode must be either 'GRAY' or 'RGB'")

    #Load the existing (previously trained) model
    network = GeneralSegmentation(False, gpus)

    #Accumulate input files; sorting by name to ensure consistant order behavior
    filenames_Inputs = natsort.natsorted(glob.glob(dir_TestingData_Inputs+'*'+fileExt), reverse=False)
    filenames_Labels = natsort.natsorted(glob.glob(dir_TestingData_Labels+'*'+fileExt), reverse=False)
    numInputs = len(filenames_Inputs)
    if numInputs == 0: sys.exit('Error - No testing data files were found in: ' + dir_TestingData_Inputs)
    if numInputs != len(filenames_Labels): sys.exit('Error - The number of inputData and label files for testing do not match')

    #Inputs and labels are paired purely by sorted position, so each pair must share the same filename
    mismatchedNames = [(inputPath, labelPath) for inputPath, labelPath in zip(filenames_Inputs, filenames_Labels) if os.path.basename(inputPath) != os.path.basename(labelPath)]
    if len(mismatchedNames) > 0: sys.exit('Error - Input and label filenames for testing do not match, first mismatch: ' + os.path.basename(mismatchedNames[0][0]) + ' / ' + os.path.basename(mismatchedNames[0][1]))

    #Load in the testing dataset
    inputs_TST_NET, inputs_TST, labels_TST, inputs_Names = [], [], [], []
    for index in tqdm(range(0, numInputs), desc = 'Loading TST Data', leave=True, ascii=asciiFlag):
        inputs_Names.append(os.path.splitext(os.path.basename(filenames_Inputs[index]))[0])

        #cv2.imread returns None (rather than raising) when a file cannot be read
        inputImage, labelImage = cv2.imread(filenames_Inputs[index]), cv2.imread(filenames_Labels[index])
        if inputImage is None: sys.exit('Error - Unable to read input file: ' + filenames_Inputs[index])
        if labelImage is None: sys.exit('Error - Unable to read label file: ' + filenames_Labels[index])

        #Network inputs are scaled to [0, 1] and arranged as [C, H, W]; a leading batch dimension is added for inference
        if inputMode=='GRAY': inputData = np.expand_dims(cv2.cvtColor(inputImage, cv2.COLOR_BGR2GRAY).astype(np.float32)/255.0, 0)
        elif inputMode=='RGB': inputData = np.moveaxis(cv2.cvtColor(inputImage, cv2.COLOR_BGR2RGB).astype(np.float32)/255.0, -1, 0)
        label = cv2.cvtColor(labelImage, cv2.COLOR_BGR2GRAY).astype(np.float32)/255.0
        inputs_TST_NET.append(torch.from_numpy(np.expand_dims(inputData, 0)).float())
        labels_TST.append(label)

        #Setup input data for visualization, arranging dimension to be [H, W, C (if applicable)]
        if inputMode=='GRAY': inputs_TST.append(inputData[0])
        elif inputMode=='RGB': inputs_TST.append(np.moveaxis(inputData, 0, -1))

    #Inference, evaluate, and visualize the testing dataset/results
    scores_Jaccard = []
    for index in tqdm(range(0, numInputs), desc = 'Testing', leave=True, ascii=asciiFlag):

        input_VIZ = inputs_TST[index]
        label_REAL = labels_TST[index]
        label_PRED, label_PRED_Processed = network.inference(inputs_TST_NET[index], True)

        #NOTE(review): astype(int) truncates, so only label pixels that were exactly 255 count as foreground; confirm labels are strictly 0/255
        score_Jaccard = jaccard_score(label_REAL.astype(int), label_PRED_Processed, pos_label=1, average='micro', zero_division=1.0)
        scores_Jaccard.append(score_Jaccard)

        #Setup blank canvas
        f = plt.figure(figsize=(12,2.5))
        f.subplots_adjust(top = 0.75)
        f.subplots_adjust(wspace=0.3, hspace=0.3)

        ax = plt.subplot2grid((1,4), (0,0))
        im = ax.imshow(input_VIZ, aspect='auto', cmap='gray', vmin=0, vmax=1, interpolation='none')
        ax.set_title('Input', fontsize=10, fontweight='bold')
        cbar = f.colorbar(im, ax=ax, orientation='vertical', pad=0.01)
        cbar.formatter.set_powerlimits((0, 0))

        ax = plt.subplot2grid((1,4), (0,1))
        im = ax.imshow(label_REAL, aspect='auto', cmap='gray', vmin=0, vmax=1, interpolation='none')
        ax.set_title('Ground-Truth', fontsize=10, fontweight='bold')
        cbar = f.colorbar(im, ax=ax, orientation='vertical', pad=0.01)
        cbar.formatter.set_powerlimits((0, 0))

        #Raw prediction is shown without fixed vmin/vmax so its full value range is visible
        ax = plt.subplot2grid((1,4), (0,2))
        im = ax.imshow(label_PRED, aspect='auto', cmap='gray', interpolation='none')
        plotTitle = 'PRED'
        ax.set_title(plotTitle, fontsize=10, fontweight='bold')
        cbar = f.colorbar(im, ax=ax, orientation='vertical', pad=0.01)
        cbar.formatter.set_powerlimits((0, 0))

        ax = plt.subplot2grid((1,4), (0,3))
        im = ax.imshow(label_PRED_Processed, aspect='auto', cmap='gray', vmin=0, vmax=1, interpolation='none')
        plotTitle = 'PRED>=0.5'
        ax.set_title(plotTitle, fontsize=10, fontweight='bold')
        cbar = f.colorbar(im, ax=ax, orientation='vertical', pad=0.01)
        cbar.formatter.set_powerlimits((0, 0))

        plotTitle = inputs_Names[index]
        plotTitle += '\nJaccard Score: ' + '{:.6f}'.format(score_Jaccard)
        plt.suptitle(plotTitle, fontsize=10, fontweight='bold')

        f.savefig(dir_TestingResults_Summary + inputs_Names[index] + '_summary.tiff', bbox_inches='tight')
        plt.close(f)

        visualizeBorderless(label_PRED_Processed, dir_TestingResults_Predictions + inputs_Names[index] + '_prediction.tiff', cmap='gray', vmin=0, vmax=1)

    #Summarize the Jaccard scores across the whole testing set (mean +/- standard deviation)
    lines = ['Jaccard Score: ' + str(np.mean(scores_Jaccard)) + ' +/- ' + str(np.std(scores_Jaccard))]
    with open(dir_TestingResults + 'dataPrintout.csv', 'w') as resultsFile:
        for line in lines:
            _ = resultsFile.write(line+'\n')
            print(line)
            print(line)

