## Import Modules and Check GPU

In [1]:
# import utility libraries
import os
import numpy as np
import gc
from datetime import datetime
from pytz import timezone
import pytz
import random

from processing_functions import *

# import deep learning libraries
from torchvision import transforms, utils
import tensorflow as tf
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
from ignite.engine import Engine, Events
from ignite.metrics import Metric, Loss
import pprint


# from monai.losses import DiceLoss
from monai.losses import DiceLoss, FocalLoss, DiceFocalLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import BasicUNet, UNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)

# import processing libraries
from patchify import unpatchify
import skimage

In [2]:
torch.cuda.empty_cache()

In [3]:
# Pick the compute device: the first CUDA GPU (cloud server) when one is visible, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# MacBook alternative (Apple MPS backend), currently disabled:
# # use_mps = torch.has_mps
# use_mps = False
# device = torch.device("mps" if use_mps else "cpu")

In [4]:
use_cuda

True

## Read TIF images into Tensors

In [5]:
import os
import torchvision
from torch.utils.data import Dataset
import tifffile
import glob

# Cloud Server
raw_path = "/home/jovyan/workspace/Images/Raw/*.tif"
mask_path = "/home/jovyan/workspace/Images/Mask/*.tif"

# # Jason's Desktop (raw strings so the backslashes are not read as escape sequences)
# raw_path = r"I:\My Drive\Raw\*.tif"
# mask_path = r"I:\My Drive\Mask\*.tif"

# # Jason's Macbook
# raw_path = "/Users/jasonfung/haaslabdataimages@gmail.com - Google Drive/My Drive/Raw/*.tif"
# mask_path = "/Users/jasonfung/haaslabdataimages@gmail.com - Google Drive/My Drive/Mask/*.tif"

# Sort before shuffling so that (a) raw image i and mask i line up, which relies on the two
# folders using matching file-name order, and (b) the seeded shuffle below is reproducible
# regardless of the order glob happens to list files in.
raw_filename_list = sorted(glob.glob(raw_path))
mask_filename_list = sorted(glob.glob(mask_path))

# Fail fast here instead of building empty or misaligned datasets further down.
if not raw_filename_list:
    raise FileNotFoundError(f"No raw images match {raw_path!r}")
if len(raw_filename_list) != len(mask_filename_list):
    raise ValueError(
        f"Found {len(raw_filename_list)} raw images but {len(mask_filename_list)} masks; "
        "every raw image needs exactly one mask")

# Shuffle both lists with the same permutation (fixed seed) so the train/test split is random but repeatable
from sklearn.utils import shuffle
raw_filename_list, mask_filename_list = shuffle(raw_filename_list, mask_filename_list, random_state = 42)

In [6]:
raw_filename_list

['/home/jovyan/workspace/Images/Raw/000_B_181107_A_N1B2_4a61736f.tif',
 '/home/jovyan/workspace/Images/Raw/000_D_180906_A_N1C1_53696168.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20180613_N4_4a61736f.tif',
 '/home/jovyan/workspace/Images/Raw/00_ML_20180614_N5_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/00_ML_20180614_N3_53696168.tif',
 '/home/jovyan/workspace/Images/Raw/00_GFPC_20180615_N1_50616967.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20180622_N2_50616967.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20190604_A_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/000_D_180907_A_N1B3_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/060_B_181031_A_N1B3_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/000_GFP_181027_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20180622_N1_53696168.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20190604_B_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/00_ML_20180614_N1_50616967.tif',
 '/home/jovyan/workspace

In [7]:
len(raw_filename_list)

16

## Processing Classes

## Patching and Reconstruction Functions

An input image has shape (# z slices, # x pixels, # y pixels). It is split into sub-volumes of size (z patch size, x patch size, y patch size), and each sub-volume keeps its grid coordinate relative to the original image. The result is a 6D array of shape (z location, x location, y location, z patch size, x patch size, y patch size).

If the image depth is not evenly divisible by the z patch size (a power of two, e.g. 16), the function patch_images() returns an "upper half" and a "lower half" of the image. For example, take an image of size (77, 512, 512) with a z patch size of 16, an x patch size of 32, and a y patch size of 32. The depth cannot be split evenly into 16-slice chunks, since 77 // 16 = 4 with a remainder of 13 (77 % 16 = 13). Therefore, taking slices from the top of the stack down to slice 64 gives an upper half of (64, 512, 512). Conversely, taking the last 16 slices from the bottom of the stack (slice 77 back to slice 61) gives a lower half of (16, 512, 512). The two (overlapping) volumes are merged back together at the end.

To prepare for training, the indexed sub-volumes are reshaped into "batch form". Continuing the previous example, the upper half (64, 512, 512) is split into the 6D array (4, 16, 16, 16, 32, 32), which is then reshaped to (4 * 16 * 16, 16, 32, 32).


## Split Training and Testing Images

In [8]:
# Patching parameters shared by the image and mask pipelines.
lateral_steps = 64
axial_steps = 16
patch_size = (axial_steps, lateral_steps, lateral_steps)
split_size = 0.8
dim_order = (0, 4, 1, 2, 3)  # axis order applied to both images and masks

# Raw-image pipeline: min-max scale, then cut into patches.
# (An optional resize step, new_shape(new_xy = (600,960)), is currently disabled.)
patch_transform = transforms.Compose([
    MinMaxScalerVectorized(),
    patch_imgs(xy_step=lateral_steps, z_step=axial_steps, patch_size=patch_size, is_mask=False),
])

# Mask pipeline: process the label masks, then cut into patches on the same grid.
# (The same optional resize step is disabled here too.)
label_transforms = transforms.Compose([
    process_masks(int_class=3),
    patch_imgs(xy_step=lateral_steps, z_step=axial_steps, patch_size=patch_size, is_mask=True),
])

# The first `split_size` fraction of the (already shuffled) file lists is used for training,
# the remainder for testing.
raw_cut = int(split_size * len(raw_filename_list))
mask_cut = int(split_size * len(mask_filename_list))
raw_training_list, raw_testing_list = raw_filename_list[:raw_cut], raw_filename_list[raw_cut:]
mask_training_list, mask_testing_list = mask_filename_list[:mask_cut], mask_filename_list[mask_cut:]
print(len(raw_training_list))

# Both datasets share every setting except their file lists and the train flag.
dataset_kwargs = dict(transform=patch_transform,
                      label_transform=label_transforms,
                      device=device,
                      img_order=dim_order,
                      mask_order=dim_order,
                      num_classes=4)
training_data = MyImageDataset(raw_training_list, mask_training_list, train=True, **dataset_kwargs)
testing_data = MyImageDataset(raw_testing_list, mask_testing_list, train=False, **dataset_kwargs)

# One whole image per DataLoader item; the train/validate functions batch its patches themselves.
from torch.utils.data import DataLoader
training_dataloader = DataLoader(training_data, batch_size=1, shuffle=False)
testing_dataloader = DataLoader(testing_data, batch_size=1, shuffle=False)

12


In [9]:
upper, upper_shape, lower, lower_shape, full_mask, mask_upper, mask_lower = next(iter(training_dataloader))

In [10]:
len(upper)

128

# Define Model and Parameters

### Model: ResUNet

In [11]:
# Training hyper-parameters and model configuration.
max_epochs = 150
dropout = 0.2
learning_rate = 5e-5
# decay = 1e-5
input_chnl = 1   # single-channel input volumes
output_chnl = 4  # one output channel per class (the datasets use num_classes = 4)

# 3D MONAI UNet with two residual units per block (num_res_units=2), moved to the selected device.
model = UNet(
    spatial_dims=3,
    in_channels=input_chnl,
    out_channels=output_chnl,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm="batch",
    dropout=dropout,
).to(device)

In [12]:
def count_parameters(model):
    """Return the total number of scalar values in ``model``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted; frozen parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print(count_parameters(model))

4810977


## Loss, Metric, Schedulers

In [13]:
# loss_function = FocalLoss()
# NOTE(review): MONAI's DiceCELoss hands its input to torch's CrossEntropyLoss, which expects raw
# logits. train()/validate() pass softmax probabilities instead, so the CE term effectively applies
# softmax twice. The usual setup is DiceCELoss(softmax=True) called on the raw model output. That
# changes the training objective, so it is only flagged here.
loss_function = DiceCELoss()
# AdamW with its default weight decay (no weight_decay argument is given).
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)
# Cosine annealing with T_max = max_epochs/30 (= 5.0 epochs for max_epochs = 150). Past T_max the
# cosine keeps going, so the LR oscillates rather than staying at its minimum.
# NOTE(review): this scheduler must be stepped without arguments; passing a loss to step() is
# interpreted as the deprecated `epoch` argument.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_epochs/30, verbose = True)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',factor=0.5, patience = 10, threshold=1e-5, threshold_mode= 'abs', verbose=True)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
# NOTE(review): `discretize` is not used anywhere in the visible notebook. `logit_thresh` is a
# deprecated AsDiscrete argument in newer MONAI releases (renamed `threshold`); confirm against
# the installed version.
discretize = Compose([Activations(softmax = True), 
                      AsDiscrete(logit_thresh=0.5)])

Adjusting learning rate of group 0 to 5.0000e-05.


## Augmentation

In [14]:
# torchio is used directly below; import it explicitly rather than relying on it being
# re-exported by `from processing_functions import *`.
import torchio

# # Earlier augmentation settings, kept for reference:
# degree = (25, 5, 5)
# translate = (10,10,10)
# transform_rotate = torchio.RandomAffine(degrees=degree, 
#                                         translation=translate, 
#                                         image_interpolation="bspline")
# transform_flip = torchio.RandomFlip(axes=('ap',))
# all_transform = torchio.Compose([transform_rotate,
#                                  transform_flip])

# Augmentation Function parameters using resnet parameters:
# random rotation of up to +/-25 degrees about the first axis only, no translation.
degree = (25, 0, 0)
# translate = (10,10,10)
transform_rotate = torchio.RandomAffine(degrees=degree,
                                        image_interpolation="bspline")
# Random flip along torchio's anatomical 'ap' (anterior-posterior) axis.
transform_flip = torchio.RandomFlip(axes=('ap',))
# Rotation followed by flip; train() passes this to augmentation(...).
all_transform = torchio.Compose([transform_rotate,
                                 transform_flip])


## Define Training and Validation Functions

In [15]:
def _stack_patches(patches, keys):
    """Stack the patches stored under ``keys`` into one mini-batch of shape (len(keys), ...).

    Assumes each dict value carries the DataLoader's leading batch dimension of size 1
    (batch_size=1): stacking on dim 1 and squeezing dim 0 removes it.
    """
    return torch.squeeze(torch.stack([patches[key] for key in keys], dim=1), dim=0)


def _train_step(sub_imgs, sub_masks, losses):
    """Run one optimizer step on a mini-batch.

    Appends the scalar loss to ``losses`` and returns the argmax class index per voxel
    (the class/channel dimension 1 is reduced away).
    """
    optimizer.zero_grad()
    output = model(sub_imgs)
    probabilities = torch.softmax(output, 1)
    # NOTE(review): MONAI's DiceCELoss feeds its input to torch's CrossEntropyLoss, which
    # expects logits. Passing probabilities applies softmax twice in the CE term. The
    # objective is kept unchanged here; consider DiceCELoss(softmax=True) with `output`.
    loss = loss_function(probabilities, sub_masks)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    return torch.argmax(probabilities, 1)


def _train_half(imgs, masks, batch_size, augment, shuffle_keys, losses):
    """Train on every sub-volume of one half (upper or lower) of a volume.

    If ``augment`` is true, the sub-volumes selected by ``get_index_nonempty_cubes`` are
    first trained on in augmented form (via ``augmentation(all_transform, ...)``). Every
    sub-volume is then trained on once, in mini-batches of ``batch_size``, with the key
    order shuffled when ``shuffle_keys`` is true.

    Returns the predictions of that second pass stacked in patch-index order (0..n-1),
    ready for reconstruction. Each step's loss is appended to ``losses``.
    """
    if augment:
        nonempty_indexes = get_index_nonempty_cubes(masks)
        for start in range(0, len(nonempty_indexes), batch_size):
            aug_imgs, aug_masks = augmentation(all_transform, imgs, masks,
                                               nonempty_indexes[start:start + batch_size])
            _train_step(aug_imgs.to(device), aug_masks.to(device), losses)

    keys = list(range(len(imgs)))
    if shuffle_keys:
        random.shuffle(keys)

    predictions = {}
    for start in range(0, len(keys), batch_size):
        batch_keys = keys[start:start + batch_size]  # slicing clamps the last, shorter batch
        prediction = _train_step(_stack_patches(imgs, batch_keys),
                                 _stack_patches(masks, batch_keys), losses)
        predictions.update(zip(batch_keys, prediction))

    # Undo the shuffle: reconstruction expects the patches in index order.
    return torch.stack([predictions[key] for key in range(len(predictions))])


def train(engine, batch):
    """Ignite training step: train on all sub-volumes of one image and reconstruct the prediction.

    ``batch`` is one item of ``training_dataloader`` and unpacks to
    ``(upper_img, upper_shape, lower_img, lower_shape, full_mask, upper_mask, lower_mask)``:

    - ``upper_img`` / ``lower_img``: dicts mapping integer patch index -> raw sub-volume.
      ``lower_img`` is ``None`` when the depth splits evenly into z patches.
    - ``upper_mask`` / ``lower_mask``: dicts with the same keys holding mask sub-volumes.
    - ``upper_shape`` / ``lower_shape``: passed through to ``reconstruct_training_masks``.
    - ``full_mask``: the whole ground-truth mask, channels last (B, Z, Y, X, C).

    Each half is trained with ``_train_half``. The un-augmented predictions are stitched
    back into a full volume so the attached Dice metric can compare it with ``full_mask``.

    Returns:
        dict with ``"batch_loss"`` (mean loss over every optimizer step, augmented steps
        included), ``"y_pred"`` (one-hot prediction, (1, C, Z, Y, X), int16) and ``"y"``
        (ground truth permuted to (B, C, Z, Y, X), on CPU).
    """
    batch_size = 64
    lateral_steps = 64
    axial_steps = 16
    patch_size = (axial_steps, lateral_steps, lateral_steps)
    augment = True
    shuffle_patches = True

    model.train()
    upper_img, upper_shape, lower_img, lower_shape, full_mask, upper_mask, lower_mask = batch

    losses = []  # one entry per optimizer step, in execution order
    upper_values = _train_half(upper_img, upper_mask, batch_size, augment, shuffle_patches, losses)
    if lower_img is None:
        # Evenly split volume: there is no lower half.
        # NOTE(review): assumes reconstruct_training_masks accepts None for the lower half - TODO confirm.
        lower_values = None
    else:
        lower_values = _train_half(lower_img, lower_mask, batch_size, augment, shuffle_patches, losses)

    orig_shape = full_mask.shape[1:-1]  # (Z, Y, X) of the channels-last mask
    reconstructed = reconstruct_training_masks(upper_values, lower_values, upper_shape,
                                               lower_shape, patch_size, orig_shape)  # returns (z,y,x)
    reconstructed = to_categorical_torch(reconstructed, num_classes = 4)  # returns (z,y,x,c)
    reconstructed = reconstructed.type(torch.int16)
    reconstructed = torch.permute(reconstructed, (3, 0, 1, 2))  # -> (c,z,y,x)
    reconstructed = torch.unsqueeze(reconstructed, 0)  # -> (batch,c,z,y,x)

    # Roll the axes of the ground-truth mask into (batch,c,z,y,x) to match y_pred.
    gt_mask = torch.permute(full_mask, dim_order).cpu()

    return {"batch_loss": sum(losses) / len(losses), "y_pred": reconstructed, "y": gt_mask}

In [16]:
def validate(engine, batch):
    """Ignite evaluation step: predict all sub-volumes of one image without updating weights.

    ``batch`` has the same layout as in ``train``:
    ``(upper_img, upper_shape, lower_img, lower_shape, full_mask, upper_mask, lower_mask)``,
    where ``lower_img``/``lower_mask`` are ``None`` when the depth splits evenly.
    Sub-volumes are predicted in index order, with no shuffling and no augmentation. The
    loss is averaged over all mini-batches, and the per-patch predictions are stitched back
    into a full volume for the Dice metric.

    Returns:
        dict with ``"batch_loss"``, ``"y_pred"`` (one-hot prediction, (1, C, Z, Y, X)) and
        ``"y"`` (ground truth permuted to (B, C, Z, Y, X), on CPU).
    """
    batch_size = 64
    axial_steps = 16
    # lateral_steps comes from the notebook-level patching parameters.
    patch_size = (axial_steps, lateral_steps, lateral_steps)

    def predict_half(imgs, masks):
        """Predict every patch of one half in index order; return (stacked predictions, batch losses)."""
        predictions, batch_losses = [], []
        keys = list(range(len(imgs)))
        for start in range(0, len(keys), batch_size):
            batch_keys = keys[start:start + batch_size]  # slicing clamps the last, shorter batch
            # Drop the DataLoader's leading batch dim of 1 -> (len(batch_keys), ...)
            sub_imgs = torch.squeeze(torch.stack([imgs[key] for key in batch_keys], dim=1), dim=0)
            sub_masks = torch.squeeze(torch.stack([masks[key] for key in batch_keys], dim=1), dim=0)
            probabilities = torch.softmax(model(sub_imgs), 1)
            # NOTE(review): DiceCELoss's cross-entropy term expects logits, not probabilities;
            # kept identical to the training objective.
            batch_losses.append(loss_function(probabilities, sub_masks).item())
            predictions.extend(torch.argmax(probabilities, 1))
        return torch.stack(predictions), batch_losses

    model.eval()
    with torch.no_grad():
        upper_img, upper_shape, lower_img, lower_shape, full_mask, upper_mask, lower_mask = batch

        upper_values, losses = predict_half(upper_img, upper_mask)
        if lower_img is None:
            # Evenly split volume: there is no lower half.
            # NOTE(review): assumes reconstruct_training_masks accepts None for the lower half - TODO confirm.
            lower_values = None
        else:
            lower_values, lower_losses = predict_half(lower_img, lower_mask)
            losses += lower_losses

        # Neuron reconstruction, so the Dice metric is computed on the whole volume.
        orig_shape = full_mask.shape[1:-1]  # (Z, Y, X) of the channels-last mask
        reconstructed = reconstruct_training_masks(upper_values, lower_values, upper_shape,
                                                   lower_shape, patch_size, orig_shape)  # returns (z,y,x)
        reconstructed = to_categorical_torch(reconstructed, num_classes = 4)  # returns (z,y,x,c)
        reconstructed = torch.permute(reconstructed, (3, 0, 1, 2))  # -> (c,z,y,x)
        reconstructed = torch.unsqueeze(reconstructed, 0)  # -> (batch,c,z,y,x)

        # Roll the axes of the ground-truth mask into (batch,c,z,y,x) to match y_pred.
        gt_mask = torch.permute(full_mask, dim_order).cpu()

    return {"batch_loss": sum(losses) / len(losses), "y_pred": reconstructed, "y": gt_mask}

In [17]:
# Wrap the per-batch functions in ignite engines: `trainer` calls train() for every item of the
# dataloader it is run on, and `evaluator` calls validate().
trainer = Engine(train)
evaluator = Engine(validate)

## Metrics and Progress Bars

In [18]:
# set up progress bar
from ignite.contrib.handlers import ProgressBar
from ignite.metrics import RunningAverage
# from ignite.metrics import ConfusionMatrix, DiceCoefficient
from monai.handlers.ignite_metric import IgniteMetric

def metric_output_transform(output):
    """Return the ``(y_pred, y)`` pair from an engine output dict, as the Dice metric expects."""
    return output["y_pred"], output["y"]

# cm = ConfusionMatrix(num_classes=4, output_transform = metric_output_transform)
# dice_metric = RunningAverage(DiceCoefficient(cm))
# Per-class Dice over the foreground classes only (include_background=False), averaged over the
# batch ('mean_batch'). get_not_nans=True makes aggregate() also return the not-nan counts.
# NOTE(review): MONAI's IgniteMetric keeps only the first element of a tuple result; confirm
# for the installed version.
dice_metric = DiceMetric(include_background=False, reduction='mean_batch', get_not_nans = True)
# Wrap the MONAI metric so it can be attached to ignite engines; it reads (y_pred, y) from each output.
metric = IgniteMetric(dice_metric, output_transform=metric_output_transform)

# progress output transform
def loss_output_transform(output):
    """Return the mean loss stored under ``"batch_loss"`` in an engine output dict."""
    return output["batch_loss"]

# Attach the Dice metric to both engines; each computes it from its own "y_pred"/"y" outputs.
metric.attach(trainer,"Dice")
metric.attach(evaluator,"Dice")

# Smoothed loss for display on both engines. The trainer needs its own "batch_loss" metric too:
# ProgressBar only shows metric names present in engine.state.metrics. Without it, the trainer's
# bar showed no loss and ignite warned about the missing name on every iteration.
RunningAverage(output_transform=loss_output_transform).attach(trainer, "batch_loss")
RunningAverage(output_transform=loss_output_transform).attach(evaluator, "batch_loss")
pbar = ProgressBar(persist=True)
pbar.attach(trainer, metric_names=["batch_loss"])
pbar.attach(evaluator, metric_names=["batch_loss"])

## Setup Model and Log Saving Directories

In [19]:
model_name = "ResUNET"
# Take a single UTC timestamp for the run so the date and time parts always describe the same
# instant; two separate now() calls can straddle midnight and produce a mismatched pair.
run_started = datetime.now(tz=pytz.utc)
date = run_started.strftime('%Y%m%d')
time = run_started.strftime('%H%M%S')

# Layout: <model_directory>/<date>/<date>_<time>/ receives this run's checkpoints,
# and its "log" sub-folder receives the TensorBoard events.
model_directory = f"/home/jovyan/workspace/results/{model_name}/"
date_directory = f"/{date}/"
time_directory = f"{date}_{time}/"
log_directory = model_directory + date_directory + time_directory + "log"
# No exist_ok on purpose: every run is expected to get a fresh, timestamped directory.
os.makedirs(log_directory)

# create writer to log results into tensorboard
log_writer = SummaryWriter(log_directory)

In [20]:
import copy

@trainer.on(Events.STARTED)
def print_start(trainer):
    """Announce the beginning of a training run."""
    message = "Training Started"
    print(message)

@trainer.on(Events.EPOCH_STARTED)
def print_epoch(trainer):
    """Print the number of the epoch that is about to run."""
    print(f"Epoch : {trainer.state.epoch}")
    
@trainer.on(Events.EPOCH_COMPLETED)
def save_model(trainer):
    """End-of-epoch handler: evaluate, log, step the LR scheduler and checkpoint the best model.

    Runs ``evaluator`` over the training and testing dataloaders and prints both metric
    dicts. It logs losses and per-class Dice to TensorBoard through ``log_writer`` and steps
    ``scheduler``. Whenever the validation loss improves, it saves ``model.state_dict()`` and
    deletes the previous best checkpoint, so only one ``.pth`` file is kept per run.

    Best-so-far values live in the module-level globals best_dice, best_epoch,
    best_epoch_file and best_loss; they are (re)initialised on epoch 1.
    """
    global best_dice
    global best_epoch
    global best_epoch_file
    global best_loss

    epoch = trainer.state.epoch
    def get_saved_model_path(epoch):
        return model_directory + date_directory + time_directory + f"{model_name}_{epoch}.pth"

    # initialize global values
    best_dice = -torch.inf if epoch == 1 else best_dice
    best_loss = torch.inf if epoch == 1 else best_loss
    best_epoch = 1 if epoch == 1 else best_epoch
    best_epoch_file = '' if epoch == 1 else best_epoch_file

    def log_training_results(trainer):
        evaluator.run(training_dataloader)
        # Get engine metrics and losses (copied, because the next evaluator run overwrites them)
        training_metrics = copy.deepcopy(evaluator.state.metrics)
        pbar.log_message(
            "Training Results - Epoch: {} \nMetrics\n{}"
            .format(trainer.state.epoch, pprint.pformat(training_metrics)))
        return training_metrics

    def log_testing_results(trainer):
        evaluator.run(testing_dataloader)
        testing_metrics = copy.deepcopy(evaluator.state.metrics)
        # ReduceLROnPlateau needs the monitored value. Epoch-based schedulers (such as
        # CosineAnnealingLR) must be stepped without arguments; a positional value would be
        # taken as the deprecated `epoch` argument, so the LR would be computed from the loss.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(testing_metrics["batch_loss"])
        else:
            scheduler.step()
        pbar.log_message(
            "Validation Results - Epoch: {} \nMetrics\n{}"
            .format(trainer.state.epoch, pprint.pformat(testing_metrics)))
        return testing_metrics

    training_metrics = log_training_results(trainer)
    testing_metrics = log_testing_results(trainer)

    # With include_background=False, index 0..2 of the Dice vectors are classes 1..3.
    train_dice = training_metrics['Dice']
    val_dice = testing_metrics['Dice']

    train_mean_dice = torch.mean(train_dice)
    val_mean_dice = torch.mean(val_dice)
    train_loss = training_metrics['batch_loss']
    val_loss = testing_metrics['batch_loss']

    # log results
    log_writer.add_scalars('Training vs. Validation Loss',
                       {'Training' : train_loss, 'Validation' : val_loss}, epoch)
    log_writer.add_scalars('Training vs. Validation Mean Dice ',
                       {'Training Mean Dice' : train_mean_dice, 'Validation Mean Dice' : val_mean_dice}, epoch)
    log_writer.add_scalars('Training vs. Validation Soma Dice ',
                       {'Training Soma Dice' : train_dice[0], 'Validation Soma Dice' : val_dice[0]}, epoch)
    log_writer.add_scalars('Training vs. Validation Dendrite Dice ',
                       {'Training Dendrite Dice' : train_dice[1], 'Validation Dendrite Dice' : val_dice[1]}, epoch)
    log_writer.add_scalars('Training vs. Validation Filopodias Dice ',
                       {'Training Filopodias Dice' : train_dice[2], 'Validation Filopodias Dice' : val_dice[2]}, epoch)
    log_writer.flush()

    # Checkpoint selection is by validation loss (the Dice-based variant is commented out).
    if val_loss < best_loss:

        # if there was a previous model saved, delete that one
        prev_best_epoch_file = get_saved_model_path(best_epoch)
        if os.path.exists(prev_best_epoch_file):
            os.remove(prev_best_epoch_file)

        # update the best loss and save the new model state
#         best_dice = val_mean_dice
        best_loss = val_loss
        best_epoch = epoch
        best_epoch_file = get_saved_model_path(best_epoch)
#         print(f'\nEpoch: {best_epoch} - New best Dice and Loss! Mean Dice: {best_dice} Loss: {best_loss}\n\n\n')
        print(f'\nEpoch: {best_epoch} - New best Loss! Loss: {best_loss}\n\n\n')
        torch.save(model.state_dict(), best_epoch_file)

## Early Stopping

In [21]:
# from ignite.handlers import EarlyStopping

# def score_function(engine):
#     val_loss = engine.state.metrics['batch_loss']
#     return -val_loss

# handler = EarlyStopping(patience=10, score_function=score_function, min_delta=1e-6, trainer=trainer)
# evaluator.add_event_handler(Events.COMPLETED, handler)

In [None]:
# Running Training Engine

trainer.run(training_dataloader, max_epochs = max_epochs)

In [None]:
trainer.state.metrics