Code for training a U-Net model, adapted from the GitHub repository
"https://github.com/milesial/Pytorch-UNet".
Identical to the train.py file, but compatible with Jupyter Notebook.

In [None]:
# import libraries
import argparse
import logging
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from sklearn.model_selection import KFold

from eval import eval_net
from unet import UNet

from datetime import datetime
now = datetime.now()

from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from utils.augmentations import AugDataset
from torch.utils.data import DataLoader, random_split

In [None]:
# Locations used throughout the notebook: input images, their masks,
# saved model checkpoints, and the cross-validation result files.
dir_img, dir_mask = './data/train/imgs/', './data/train/masks/'
dir_checkpoint = './checkpoints/'
dir_cv = './cv/'

In [None]:
# function for training the NN
def train_net(net, device, dataset, train_set, val_set, fold=0, epochs=5, batch_size=1,
              lr=0.0001, save_cp=True, img_scale=0.5, data_aug=5):
    """Train ``net`` on one cross-validation fold and record its validation score.

    Relies on three module-level names that must be defined before the call:
    ``writer`` (the TensorBoard ``SummaryWriter``), ``out_txt`` (an open text
    file receiving one line per epoch) and ``dir_checkpoint``.

    Args:
        net: model to train; must expose ``n_channels`` and ``n_classes``.
        device: torch device each batch is moved to.
        dataset: full dataset; not used here, kept for interface compatibility.
        train_set: training subset for this fold (wrapped in ``AugDataset``).
        val_set: validation subset for this fold.
        fold: zero-based fold index (reported one-based in logs and files).
        epochs: number of passes over the training loader.
        batch_size: batch size used by both data loaders.
        lr: learning rate for the RMSprop optimizer.
        save_cp: when true, save the model state dict after every epoch.
        img_scale: not used here, kept for interface compatibility.
        data_aug: forwarded to ``AugDataset`` as ``num`` for the training set.
    """
    # track number of batches passed through the network
    global_step = 0
    # last validation score seen; stays None only if validation never runs
    # (e.g. an empty training set), which avoids a NameError further down
    val_score = None
    # function will output txt file with the performance of the model on each epoch and fold
    out_txt.write(f'Fold {fold +1} \n')  # write fold number in file

    # Define data loaders for training and validation set in this fold
    train_loader = DataLoader(AugDataset(train_set, num=data_aug), batch_size=batch_size,
                              pin_memory=True, num_workers=0)
    val_loader = DataLoader(AugDataset(val_set, transform=None), batch_size=batch_size,
                            pin_memory=True, num_workers=0)

    # Validate roughly ten times per epoch. Clamp to 1 so that a training set
    # smaller than 10 * batch_size does not produce a modulo-by-zero.
    val_interval = max(1, len(train_loader.dataset) // (10 * batch_size))

    # initialize optimizer
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # learning rate scheduler: minimise the loss for multi-class, maximise Dice for binary
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()  # use CrossEntropyLoss for multi-class segmentation
    else:
        criterion = nn.BCEWithLogitsLoss()  # use BCEWithLogitsLoss for binary segmentation

    # run the training loop for defined number of epochs
    for epoch in range(epochs):

        net.train()

        # accumulated loss over the epoch
        epoch_loss = 0
        # display epoch information and initiate training progress bar
        with tqdm(total=len(train_loader.dataset),
                  desc=f'Fold {fold + 1}, Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            # iterate over the DataLoader with training data
            for batch in train_loader:
                # define images and masks
                imgs = batch['image']
                true_masks = batch['mask']
                # ensure that number of channels in input images are the same as the defined number of channels
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'
                imgs = imgs.to(device=device, dtype=torch.float32)
                # BCEWithLogitsLoss needs float targets, CrossEntropyLoss needs class indices
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                # perform forward pass
                masks_pred = net(imgs)

                # compute loss
                loss = criterion(masks_pred, true_masks)
                batch_loss = loss.item()
                # update loss
                epoch_loss += batch_loss
                # save loss to summary
                writer.add_scalar('Loss/train', batch_loss, global_step)
                # display loss
                pbar.set_postfix(**{'loss (batch)': batch_loss})

                # zero the gradients
                optimizer.zero_grad()

                # perform backward pass
                loss.backward()
                # gradient clipping
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                # perform optimization
                optimizer.step()

                # update progress bar based on how many images have passed through the model
                pbar.update(imgs.shape[0])
                global_step += 1

                if global_step % val_interval != 0:
                    continue

                # save weights and biases to summary
                for tag, value in net.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                    # parameters that received no gradient have grad == None
                    if value.grad is not None:
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)

                # evaluate model
                val_score = eval_net(net, val_loader, device)
                # eval_net is defined elsewhere; re-enable training mode in
                # case it left the network in evaluation mode
                net.train()
                # decay learning rate
                scheduler.step(val_score)
                # save learning rate to summary
                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                # print validation cross entropy or Dice coefficient depending on mask classes
                if net.n_classes > 1:
                    logging.info('Validation cross entropy: {}'.format(val_score))
                    writer.add_scalar('Loss/test', val_score, global_step)  # save loss
                else:
                    logging.info('Validation Dice Coeff: {}'.format(val_score))
                    writer.add_scalar('Dice/test', val_score, global_step)  # save Dice score

                # add batch image data to summary
                writer.add_images('images', imgs, global_step)
                if net.n_classes == 1:
                    writer.add_images('masks/true', true_masks, global_step)
                    writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

        # save checkpoint details
        if save_cp:
            # create checkpoint directory if there isn't one
            if not os.path.isdir(dir_checkpoint):
                os.makedirs(dir_checkpoint, exist_ok=True)
                logging.info('Created checkpoint directory')
            # name checkpoint file after the fold and epoch
            torch.save(net.state_dict(),
                       os.path.join(dir_checkpoint, f'CP_fold{fold + 1}_epoch{epoch + 1}.pth'))
            logging.info(f'Checkpoint {epoch + 1}, fold {fold +1} saved !')

        # write the most recent validation score of this epoch in output txt file
        out_txt.write(f'Epoch{epoch +1}: Dice Coeff {val_score} \n')
    # write dice scores of next fold in a new line
    out_txt.write('\n')

In [None]:
# pass arguments
class empty:
    """Bare attribute container standing in for parsed command-line arguments."""


args = empty()
vars(args).update(
    # number of epochs
    epochs=5,
    # load model from a .pth file
    load=False,
    # batch size (set to 1 as training images had slightly different dimensions)
    batch_size=1,
    # learning rate
    lr=0.0001,
    # image scale
    img_scale=0.5,
    # number of folds in k-fold cross-validation
    k_folds=5,
    # number of augmented images generated for each image in the train set
    data_aug=2,
    # whether checkpoint will be saved
    save_cp=True,
)

In [None]:
# Run K-fold cross-validation: for every fold a fresh U-Net is built and
# trained with train_net, which writes its per-epoch scores to out_txt and its
# TensorBoard summaries to writer (both module-level names it reads directly).

# initialise logging INFO, for displaying training information
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
# set device to cuda if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# display device used
logging.info(f'Using device {device}')

# build the dataset from the image and mask directories
dataset = BasicDataset(dir_img, dir_mask, args.img_scale)

# writer will output to ./runs/ directory by default
# NOTE: the same writer is shared by all folds and train_net restarts its
# step counter at 0 for each fold, so curves of different folds overlap.
writer = SummaryWriter(comment=f'LR_{args.lr}_BS_{args.batch_size}_SCALE_{args.img_scale}')
# print summary of input parameters
logging.info(f'''Starting training:
    Epochs:          {args.epochs}
    Batch size:      {args.batch_size}
    Learning rate:   {args.lr}
    Training size:   {len(dataset)}
    K-folds:         {args.k_folds}
    Checkpoints:     {args.save_cp}
    Device:          {device.type}
    Images scaling:  {args.img_scale}
    Data augmentations per image: {args.data_aug}
''')


# define the K-fold Cross Validator
kfold = KFold(n_splits=args.k_folds, shuffle=False)

# create directory for saving the cross-validation output txt file
os.makedirs(dir_cv, exist_ok=True)
# create and name output txt file (variable named out_txt, used by train_net)
out_txt = open(
    os.path.join(
        dir_cv,
        f'{now.strftime("%b-%d-%Y_%H-%M-%S")}_LR_{args.lr}_BS_{args.batch_size}_SCALE_{args.img_scale}.txt',
    ),
    'w',
)

try:
    # K-fold Cross Validation model evaluation
    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):

        # display fold number
        logging.info(f'''FOLD {fold +1}''')

        # initialise a fresh U-Net for this fold and define number of channels,
        # number of classes and up-scaling technique
        net = UNet(n_channels=3, n_classes=1, bilinear=True)

        # report the configuration, including which up-scaling variant is active
        logging.info(f'Network:\n'
                     f'\t{net.n_channels} input channels\n'
                     f'\t{net.n_classes} output channels (classes)\n'
                     f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

        net.to(device=device)

        # Dividing data into folds
        train_set = torch.utils.data.dataset.Subset(dataset, train_idx)
        val_set = torch.utils.data.dataset.Subset(dataset, val_idx)

        # train model
        try:
            train_net(net=net,
                      dataset=dataset,
                      train_set=train_set,
                      val_set=val_set,
                      fold=fold,
                      epochs=args.epochs,
                      batch_size=args.batch_size,
                      lr=args.lr,
                      device=device,
                      # honour the configured checkpoint setting instead of
                      # silently falling back to train_net's default
                      save_cp=args.save_cp,
                      img_scale=args.img_scale,
                      data_aug=args.data_aug)

        # save the current weights and stop cross-validation if training is
        # interrupted; leaving the loop (rather than forcing the process to
        # exit) keeps a notebook kernel alive and lets the cleanup below run
        except KeyboardInterrupt:
            torch.save(net.state_dict(), 'INTERRUPTED.pth')
            logging.info('Saved interrupt')
            break
finally:
    # close output txt and summary writer on completion, interrupt or error
    out_txt.close()
    writer.close()