## Import Modules and Check GPU

In [1]:
# import utility libraries
import os
import numpy as np
import gc
import datetime
import random

from processing_functions import *

# import deep learning libraries
from torchvision import transforms, utils
import tensorflow as tf
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange

# import processing libraries
from patchify import unpatchify
import skimage

In [2]:
# Ask PyTorch's CUDA caching allocator to release cached memory blocks that are
# currently unused, so a re-run of the notebook starts with as much free GPU
# memory as possible.
torch.cuda.empty_cache()

In [3]:
# Device selection.
# Cloud server: use the first CUDA device when one is visible, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# MacBook alternative (disabled):
# # use_mps = torch.has_mps
# use_mps = False
# device = torch.device("mps" if use_mps else "cpu")

In [4]:
device

device(type='cuda', index=0)

## Read TIF images into Tensors

In [5]:
import os
import torchvision
from torch.utils.data import Dataset
import tifffile
import glob
from sklearn.utils import shuffle

# Glob patterns for the raw volumes and their labelled masks (cloud server).
raw_path = "/home/jovyan/workspace/Images/Raw/*.tif"
mask_path = "/home/jovyan/workspace/Images/Mask/*.tif"

# Alternative locations (disabled):
# # Jason's Desktop
# raw_path = "H:\My Drive\Raw\*.tif"
# mask_path = "H:\My Drive\Mask\*.tif"
#
# # Jason's Macbook
# raw_path = "/Users/jasonfung/haaslabdataimages@gmail.com - Google Drive/My Drive/Raw/*.tif"
# mask_path = "/Users/jasonfung/haaslabdataimages@gmail.com - Google Drive/My Drive/Mask/*.tif"

# Sort both listings before shuffling so the starting order does not depend on
# the order in which the filesystem returns the files.
# NOTE(review): raw/mask pairing relies on the two sorted listings lining up
# index-for-index — confirm the mask files are named to match the raw files.
sorted_raw_files = sorted(glob.glob(raw_path))
sorted_mask_files = sorted(glob.glob(mask_path))

# Shuffle both lists with the same permutation (fixed seed for reproducibility).
raw_filename_list, mask_filename_list = shuffle(
    sorted_raw_files,
    sorted_mask_files,
    random_state=42,
)


In [6]:
raw_filename_list

['/home/jovyan/workspace/Images/Raw/000_ML_20190604_B_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/00_ML_20180614_N1_50616967.tif',
 '/home/jovyan/workspace/Images/Raw/000_B_181107_A_N1B2_4a61736f.tif',
 '/home/jovyan/workspace/Images/Raw/00_ML_20180614_N5_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20180621_A_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20190604_A_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/000_D_180907_A_N1B3_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/000_D_180906_A_N1C1_53696168.tif',
 '/home/jovyan/workspace/Images/Raw/060_B_181031_A_N1B3_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20180613_N4_4a61736f.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20180622_N2_50616967.tif',
 '/home/jovyan/workspace/Images/Raw/00_GFPC_20180615_N1_50616967.tif',
 '/home/jovyan/workspace/Images/Raw/000_GFP_181027_52616a61.tif',
 '/home/jovyan/workspace/Images/Raw/000_ML_20180622_N1_53696168.tif']

In [7]:
len(raw_filename_list)

14

## Processing Classes

## Patching and Reconstruction Functions

The input image has shape (# z slices, # x pixels, # y pixels). It is split into sub-volumes of size (z patch size, x patch size, y patch size), and each sub-volume is indexed by its position within the original image (it is "voxelized"). The result is a 6D array of shape (z location, x location, y location, z patch size, x patch size, y patch size).

If the image depth is not evenly divisible by the z patch size (a power of two, 2^n for an integer n), the patching function patch_images() returns an "upper half" and a "lower half" of the image. For example, given an image of size (77, 512, 512) with a z patch size of 16, an x patch size of 32, and a y patch size of 32, the depth cannot be split evenly into blocks of 16, since 77 / 16 = 4 remainder 13. Therefore, starting from the top of the stack and going to slice 64, the upper half becomes (64, 512, 512). Conversely, starting from the bottom of the stack at slice 77 and taking the last 16 slices (down to, but not including, slice 61), the lower half becomes (16, 512, 512). The two halves overlap by 3 slices (64 + 16 - 77) and are merged at the end.

To prepare for training, the position-indexed sub-volumes are reshaped into "batch form". Continuing the previous example, the upper half (64, 512, 512) is split into a 6D array of shape (4, 16, 16, 16, 32, 32), which is then flattened to (4 * 16 * 16, 16, 32, 32) = (1024, 16, 32, 32).


## Split Training and Testing Images

In [9]:
from torch.utils.data import DataLoader

# Patching parameters: every sub-volume is (axial, lateral, lateral) voxels.
lateral_steps = 64
axial_steps = 16
patch_size = (axial_steps, lateral_steps, lateral_steps)
split_size = 0.8           # fraction of the files used for training
dim_order = (0,4,1,2,3)    # dimension order applied to both images and masks

# Transform pipeline for the raw images: intensity scaling, then patching.
# (Optional resize step, disabled: new_shape(new_xy = (600,960)))
patch_transform = transforms.Compose([
    MinMaxScalerVectorized(),
    patch_imgs(xy_step = lateral_steps, z_step = axial_steps, patch_size = patch_size, is_mask = False),
])

# Transform pipeline for the labelled masks: mask processing, then patching.
# (Optional resize step, disabled: new_shape(new_xy = (600,960)))
label_transforms = transforms.Compose([
    process_masks(int_class = 3),
    patch_imgs(xy_step = lateral_steps, z_step = axial_steps, patch_size = patch_size, is_mask = True),
])

# Split the (already shuffled) file lists into a training head and a testing tail.
raw_split_index = int(split_size * len(raw_filename_list))
mask_split_index = int(split_size * len(mask_filename_list))

raw_training_list = raw_filename_list[:raw_split_index]
raw_testing_list = raw_filename_list[raw_split_index:]
mask_training_list = mask_filename_list[:mask_split_index]
mask_testing_list = mask_filename_list[mask_split_index:]
print(len(raw_training_list))

# Arguments shared by the training and testing datasets.
dataset_kwargs = dict(
    transform = patch_transform,
    label_transform = label_transforms,
    device = device,
    img_order = dim_order,
    mask_order = dim_order,
    num_classes = 4,
)

training_data = MyImageDataset(raw_training_list, mask_training_list, train=True, **dataset_kwargs)
testing_data = MyImageDataset(raw_testing_list, mask_testing_list, train=False, **dataset_kwargs)

# One whole image per iteration; only the training order is randomised.
training_dataloader = DataLoader(training_data, batch_size = 1, shuffle = True)
testing_dataloader = DataLoader(testing_data, batch_size = 1, shuffle = False)

11


# Training the Model

In [14]:
# from monai.losses import DiceLoss
from monai.losses import FocalLoss, DiceFocalLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import BasicUNet, UNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)

## Model and Model Parameters

### UNet

In [11]:
# Hyper-parameters
max_epochs = 250
dropout = 0.10
learning_rate = 5e-5
# NOTE(review): `decay` is defined but never passed to the optimizer below, so
# AdamW runs with its own default weight decay. If 1e-5 was intended, pass
# `weight_decay = decay` — left unchanged here because it alters training.
decay = 1e-5
input_chnl = 1
output_chnl = 4

# 3D U-Net: one input channel, one output channel per class.
model = BasicUNet(spatial_dims=3, 
                  in_channels = input_chnl,
                  out_channels = output_chnl,
                  features = (16, 32, 64, 128, 256, 16),
                  norm = "batch",
                  dropout = dropout,
               )
model = model.to(device)

# Loss and optimizer.
# NOTE(review): train()/validate() pass softmax probabilities (not logits) to
# this loss. The cross-entropy term of DiceCELoss normally expects logits —
# confirm whether the loss should instead be built with softmax=True and fed
# the raw network output.
# loss_function = FocalLoss()
loss_function = DiceCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Dice metric (the earlier default-constructed DiceMetric() was immediately
# overwritten by this one, so only this instance is created).
dice = DiceMetric(include_background=True, reduction="mean", get_not_nans = True)

# Post-processing pipeline: softmax followed by discretisation.
discretize = Compose([Activations(softmax = True), 
                      AsDiscrete(logit_thresh=0.5)])

BasicUNet features: (16, 32, 64, 128, 256, 16).


In [18]:
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted; frozen
    parameters are ignored.
    """
    trainable_total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            trainable_total += parameter.numel()
    return trainable_total

print(count_parameters(model))

4810977


## Augmentation

In [None]:
# Augmentation pipeline: a random affine rotation followed by a random flip.
#
# NOTE(review): `torchio` is not imported in any cell visible in this notebook.
# Presumably it is re-exported by `from processing_functions import *` — confirm,
# or add `import torchio` to the import cell so Restart & Run All works.
#
# An earlier variant of this cell additionally passed translation=(10,10,10) to
# RandomAffine; that option is kept below, commented out.

# Augmentation Function parameters using resnet parameters
degree = (25, 5, 5)
# translate = (10,10,10)
transform_rotate = torchio.RandomAffine(degrees=degree, 
#                                         translation=translate, 
                                        image_interpolation="bspline")
transform_flip = torchio.RandomFlip(axes=('ap',))
# Compose both steps; train() hands this object to augmentation().
all_transform = torchio.Compose([transform_rotate,
                                 transform_flip])


## Compose Dice Metric and Discretize Values

In [None]:
# Initialize metric lists to store Dice metrics
mean_dice = []
dice_soma = []
dice_dendrite = []
dice_filopodias = []

## Define Training and Validation Functions

In [None]:
def _train_batches(keys, batch_size):
    """Yield consecutive chunks of ``keys`` with at most ``batch_size`` items (with a progress bar)."""
    for start in trange(0, len(keys), batch_size):
        # Slicing clamps at the end of the sequence, so the final (shorter)
        # batch needs no special handling.
        yield keys[start:start + batch_size]


def _train_gather(volumes, keys):
    """Stack the sub-volumes stored under ``keys`` in ``volumes`` into one batch tensor."""
    # Stack along dim 1 and drop dim 0 — presumably the size-1 DataLoader batch
    # axis carried by every sub-volume; TODO confirm against MyImageDataset.
    return torch.squeeze(torch.stack([volumes.get(key) for key in keys], dim=1), dim=0)


def _train_step(model, sub_imgs, sub_masks):
    """Run one optimisation step and return ``(class prediction, scalar loss)``."""
    optimizer.zero_grad()
    output = model(sub_imgs)
    probabilities = torch.softmax(output, 1)
    prediction = torch.argmax(probabilities, 1)

    # NOTE(review): the loss is given softmax probabilities, as in the original
    # code. If loss_function is DiceCELoss() its cross-entropy term normally
    # expects logits — confirm and adjust together with the loss construction.
    current_loss = loss_function(probabilities, sub_masks)
    current_loss.backward()
    optimizer.step()
    return prediction, current_loss.item()


def _train_volume(model, imgs, masks, key_list, batch_size, augment,
                  augment_msg=None, plain_msg=None):
    """Train on one (half-)volume: optional augmented pass, then the plain pass.

    Returns ``(predictions, loss_sum, n_batches)`` where ``predictions`` maps
    each sub-volume key to its predicted class map from the plain pass.
    """
    predictions = {}
    loss_sum = 0.0
    n_batches = 0

    if augment:
        if augment_msg is not None:
            print(augment_msg)
        # Augment only the sub-volumes reported by get_index_nonempty_cubes.
        nonempty_indexes = get_index_nonempty_cubes(masks)
        for index_batch in _train_batches(nonempty_indexes, batch_size):
            sub_imgs, sub_masks = augmentation(all_transform, imgs, masks, index_batch)
            _, batch_loss = _train_step(model, sub_imgs, sub_masks)
            loss_sum += batch_loss
            n_batches += 1

    if plain_msg is not None:
        print(plain_msg)
    for batch_keys in _train_batches(key_list, batch_size):
        sub_imgs = _train_gather(imgs, batch_keys)
        sub_masks = _train_gather(masks, batch_keys)
        prediction, batch_loss = _train_step(model, sub_imgs, sub_masks)

        # Remember every prediction under its key so the full mask can be
        # rebuilt in the original order even when the keys were shuffled.
        predictions.update(zip(batch_keys, prediction))
        loss_sum += batch_loss
        n_batches += 1

    return predictions, loss_sum, n_batches


def train(batch_size, model, patch_size, augment = True, shuffle = True):
    """Train ``model`` for one epoch over ``training_dataloader``.

    Relies on module-level objects defined in earlier cells:
    ``training_dataloader``, ``optimizer``, ``loss_function``, ``dice``,
    ``all_transform`` and ``dim_order``, plus the helpers
    ``get_index_nonempty_cubes``, ``augmentation``,
    ``reconstruct_training_masks`` and ``to_categorical_torch``.

    Args:
        batch_size: Number of sub-volumes per optimisation step.
        model: Network to train (switched to train mode here).
        patch_size: Sub-volume size, forwarded to the mask reconstruction.
        augment: When True, images that have a lower half also get an extra
            pass over augmented versions of their non-empty sub-volumes.
        shuffle: When True, sub-volumes are visited in random order.

    Returns:
        ``(running_dice, mean_loss)`` where ``running_dice`` is the list
        ``[background, soma, dendrite, filopodia]`` of Dice scores averaged
        over the images, and ``mean_loss`` is the loss averaged over batches.
    """
    model.train()

    running_loss = 0
    count_loss = 0
    # Running Dice sums for each class
    dice_mean = dice_background = dice_soma = dice_dendrite = dice_filopodias = 0.0
    dice_count = 0
    img_num = 0

    for upper_img, upper_shape, lower_img, lower_shape, full_mask, upper_mask, lower_mask in training_dataloader:

        img_num += 1
        print("Training Image: ", img_num)

        # `is None` instead of `== None`: an identity test that is also safe
        # if lower_img is a tensor-like object.
        has_lower = lower_img is not None

        upper_key_list = list(range(len(upper_img)))
        lower_key_list = list(range(len(lower_img))) if has_lower else []
        if shuffle:
            random.shuffle(upper_key_list)
            if has_lower:
                # Previously the upper list was shuffled a second time here and
                # the lower list was never shuffled.
                random.shuffle(lower_key_list)

        if not has_lower:
            # Evenly split image: a single volume and, as before, no
            # augmentation pass in this branch.
            tmp_upper_dict, loss_sum, n_batches = _train_volume(
                model, upper_img, upper_mask, upper_key_list, batch_size, augment=False)
            tmp_lower_dict = {}
            running_loss += loss_sum
            # These are loss batches; they used to be added to dice_count,
            # which left count_loss at zero and inflated the Dice divisor.
            count_loss += n_batches
        else:
            print("Training Upper Half of Image")
            tmp_upper_dict, loss_sum, n_batches = _train_volume(
                model, upper_img, upper_mask, upper_key_list, batch_size, augment,
                augment_msg="Augmenting Images")
            running_loss += loss_sum
            count_loss += n_batches

            print("Training Lower Half of Image")
            # The lower half is now augmented from its own images and masks
            # (previously the upper-half data was reused here).
            tmp_lower_dict, loss_sum, n_batches = _train_volume(
                model, lower_img, lower_mask, lower_key_list, batch_size, augment,
                augment_msg="Augmenting Lower Half Images",
                plain_msg="Non Augmented Image")
            running_loss += loss_sum
            count_loss += n_batches

        # neuron reconstruction to calculate the dice metric.
        orig_shape = full_mask.shape[1:-1]
        reconstructed_mask_order = (3,0,1,2)

        upper_values = torch.stack([tmp_upper_dict[key] for key in range(len(tmp_upper_dict))])
        if tmp_lower_dict:
            lower_values = torch.stack([tmp_lower_dict[key] for key in range(len(tmp_lower_dict))])
        else:
            # torch.stack raises on an empty list. The original code set the
            # lower result to None in this case.
            # NOTE(review): assumes reconstruct_training_masks accepts None for
            # a missing lower half — confirm.
            lower_values = None

        reconstructed = reconstruct_training_masks(upper_values, lower_values, upper_shape,
                                                   lower_shape, patch_size, orig_shape) # returns (z,y,x)
        reconstructed = to_categorical_torch(reconstructed, num_classes = 4) # returns (z,y,x,c)
        reconstructed = torch.permute(reconstructed, reconstructed_mask_order)
        reconstructed = torch.unsqueeze(reconstructed, 0) # make reconstructed image into (Batch,c,z,y,x)

        gt_mask = torch.permute(full_mask, dim_order).cpu() # roll axis of ground truth mask into (batch,c,z,y,x)

        # compute the dice score for each class
        scores = dice(reconstructed, gt_mask)
        scores = scores[0]
        dice_mean += scores.mean()
        dice_background += scores[0]
        dice_soma += scores[1]
        dice_dendrite += scores[2]
        dice_filopodias += scores[3]

        dice_count += 1

    running_dice = [dice_background/dice_count, dice_soma/dice_count, dice_dendrite/dice_count, dice_filopodias/dice_count]

    return running_dice, running_loss / count_loss

In [None]:
def _validate_batches(keys, batch_size):
    """Yield consecutive chunks of ``keys`` with at most ``batch_size`` items (with a progress bar)."""
    for start in trange(0, len(keys), batch_size):
        # Slicing clamps at the end of the sequence, so the final (shorter)
        # batch needs no special handling.
        yield keys[start:start + batch_size]


def _validate_gather(volumes, keys):
    """Stack the sub-volumes stored under ``keys`` in ``volumes`` into one batch tensor."""
    # Stack along dim 1 and drop dim 0 — presumably the size-1 DataLoader batch
    # axis carried by every sub-volume; TODO confirm against MyImageDataset.
    return torch.squeeze(torch.stack([volumes.get(key) for key in keys], dim=1), dim=0)


def _validate_volume(model, imgs, masks, key_list, batch_size):
    """Predict every sub-volume of one (half-)volume without updating the model.

    Must be called under ``torch.no_grad()``. Returns
    ``(predictions, loss_sum, n_batches)`` where ``predictions`` maps each
    sub-volume key to its predicted class map.
    """
    predictions = {}
    loss_sum = 0.0
    n_batches = 0

    for batch_keys in _validate_batches(key_list, batch_size):
        sub_imgs = _validate_gather(imgs, batch_keys)
        sub_masks = _validate_gather(masks, batch_keys)

        # No optimizer.zero_grad() here: nothing is back-propagated in validation.
        output = model(sub_imgs)
        probabilities = torch.softmax(output, 1)
        prediction = torch.argmax(probabilities, 1)
        predictions.update(zip(batch_keys, prediction))

        # NOTE(review): the loss is given softmax probabilities, as in the
        # original code; confirm this matches how loss_function is built.
        current_loss = loss_function(probabilities, sub_masks)
        loss_sum += current_loss.item()
        n_batches += 1

    return predictions, loss_sum, n_batches


def validate(batch_size, model, patch_size, shuffle = False):
    """Evaluate ``model`` over ``testing_dataloader`` without gradient tracking.

    Relies on module-level objects defined in earlier cells:
    ``testing_dataloader``, ``loss_function``, ``dice`` and ``dim_order``,
    plus the helpers ``reconstruct_training_masks`` and
    ``to_categorical_torch``.

    Args:
        batch_size: Number of sub-volumes per forward pass.
        model: Network to evaluate (switched to eval mode here).
        patch_size: Sub-volume size, forwarded to the mask reconstruction.
        shuffle: When True, sub-volumes are visited in random order.

    Returns:
        ``(running_dice, mean_loss)`` where ``running_dice`` is the list
        ``[background, soma, dendrite, filopodia]`` of Dice scores averaged
        over the images, and ``mean_loss`` is the loss averaged over batches.
    """
    model.eval()
    with torch.no_grad():
        running_loss = 0
        count_loss = 0
        # Running Dice sums for each class
        dice_mean = dice_background = dice_soma = dice_dendrite = dice_filopodias = 0.0
        dice_count = 0
        img_num = 0

        for upper_img, upper_shape, lower_img, lower_shape, full_mask, upper_mask, lower_mask in testing_dataloader:
            img_num += 1
            # The message used to say "Training Image" although this loop
            # iterates over the testing data.
            print("Validation Image: ", img_num)

            # `is None` instead of `== None`: an identity test that is also
            # safe if lower_img is a tensor-like object.
            has_lower = lower_img is not None

            upper_key_list = list(range(len(upper_img)))
            lower_key_list = list(range(len(lower_img))) if has_lower else []
            if shuffle:
                random.shuffle(upper_key_list)
                if has_lower:
                    # Previously the upper list was shuffled a second time and
                    # the lower list was never shuffled.
                    random.shuffle(lower_key_list)

            if not has_lower:
                # Evenly split image: a single volume.
                tmp_upper_dict, loss_sum, n_batches = _validate_volume(
                    model, upper_img, upper_mask, upper_key_list, batch_size)
                tmp_lower_dict = {}
                running_loss += loss_sum
                # These are loss batches; they used to be added to dice_count,
                # which left count_loss at zero and inflated the Dice divisor.
                count_loss += n_batches
            else:
                print("Validating Upper Half of Image")
                tmp_upper_dict, loss_sum, n_batches = _validate_volume(
                    model, upper_img, upper_mask, upper_key_list, batch_size)
                running_loss += loss_sum
                count_loss += n_batches

                print("Validating Lower Half of Image")
                tmp_lower_dict, loss_sum, n_batches = _validate_volume(
                    model, lower_img, lower_mask, lower_key_list, batch_size)
                running_loss += loss_sum
                count_loss += n_batches

            # neuron reconstruction to calculate the dice metric.
            orig_shape = full_mask.shape[1:-1]
            reconstructed_mask_order = (3,0,1,2)

            upper_values = torch.stack([tmp_upper_dict[key] for key in range(len(tmp_upper_dict))])
            if tmp_lower_dict:
                lower_values = torch.stack([tmp_lower_dict[key] for key in range(len(tmp_lower_dict))])
            else:
                # torch.stack raises on an empty list. The original code set
                # the lower result to None in this case.
                # NOTE(review): assumes reconstruct_training_masks accepts None
                # for a missing lower half — confirm.
                lower_values = None

            reconstructed = reconstruct_training_masks(upper_values, lower_values, upper_shape,
                                                       lower_shape, patch_size, orig_shape) # returns (z,y,x)
            reconstructed = to_categorical_torch(reconstructed, num_classes = 4) # returns (z,y,x,c)
            reconstructed = torch.permute(reconstructed, reconstructed_mask_order)
            reconstructed = torch.unsqueeze(reconstructed, 0) # make reconstructed image into (Batch,c,z,y,x)

            gt_mask = torch.permute(full_mask, dim_order).cpu() # roll axis of ground truth mask into (batch,c,z,y,x)

            # compute the dice score for each class
            scores = dice(reconstructed, gt_mask)
            scores = scores[0]
            dice_mean += scores.mean()
            dice_background += scores[0]
            dice_soma += scores[1]
            dice_dendrite += scores[2]
            dice_filopodias += scores[3]

            dice_count += 1

        running_dice = [dice_background/dice_count, dice_soma/dice_count, dice_dendrite/dice_count, dice_filopodias/dice_count]

        return running_dice, running_loss / count_loss

In [None]:
# TensorBoard log directory; "{}" is filled with the run timestamp.
# # macbook log directory
# log_directory = '/Users/jasonfung/Documents/Masters Project/Results/runs/unet/trainer_{}'

# windows log directory
# log_directory = "\\Users\\Fungj\\Documents\\Results_{}"

# # cloud log and model directory
log_directory = "/home/jovyan/workspace/results/logs/unet/trainer_{}" 
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(log_directory.format(timestamp))

# Checkpoint directory. Create it up front: torch.save does not create missing
# parent directories, so the first checkpoint would otherwise fail on a fresh
# machine after a full training epoch has already been spent.
output_dir = '/home/jovyan/workspace/results/models/unet/'
os.makedirs(output_dir, exist_ok=True)


best_val_loss = np.inf
batch_size = 32

for epoch in range(max_epochs):
    print(f' Epoch {epoch}: ')
    train_dice, train_loss = train(batch_size, model, patch_size)
    val_dice, val_loss = validate(batch_size, model, patch_size)
    
    if epoch % 10 == 0:
        # The validation metric is a Dice score (it was previously labelled "IOU").
        print(f'Epoch {epoch}: Loss train {train_loss}; Dice train {train_dice}; validation {val_loss}; validation Dice {val_dice}')

    # Log the running loss averaged per batch
    # for both training and validation.
    # Tag strings are kept exactly as before so new runs line up with existing logs.
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training' : train_loss, 'Validation' : val_loss}, epoch)
    writer.add_scalars('Training vs. Validation Mean Dice ',
                       {'Training Mean Dice' : np.mean(train_dice), 'Validation Mean Dice' : np.mean(val_dice)}, epoch)
    writer.add_scalars('Training vs. Validation Soma Dice ',
                       {'Training Soma Dice' : train_dice[1], 'Validation Soma Dice' : val_dice[1]}, epoch)
    writer.add_scalars('Training vs. Validation Dendrite Dice ',
                       {'Training Dendrite Dice' : train_dice[2], 'Validation Dendrite Dice' : val_dice[2]}, epoch)
    writer.add_scalars('Training vs. Validation Filopodias Dice ',
                       {'Training Filopodias Dice' : train_dice[3], 'Validation Filopodias Dice' : val_dice[3]}, epoch)
    writer.flush()

    # Keep a checkpoint every time the validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        model_path = output_dir + f'model_{timestamp}_{epoch}'
        torch.save(model.state_dict(), model_path)

# Flush pending events and release the log file handles once training is done.
writer.close()