# Training notebook for semantic segmentation of embryos in microscopy images

A notebook outlining the training procedure for a variety of popular semantic segmentation models, such as UNet and DeepLab V3. Because this application only includes a single embryo in any given microscopy image, we frame this as a binary segmentation task. If your application includes multiple embryos in a given image, you will instead need to use an instance segmentation model such as Mask-RCNN.


## Dependencies

The following are required dependencies for this script. We also set up mixed precision training for the speedup it provides in training time.

In [None]:
import glob
import vuba
import matplotlib.pyplot as plt
import random
import numpy as np
import cv2
import vuba
from tensorflow import keras
import re
from tensorflow.keras import layers
from tqdm import tqdm
import time
import tensorflow as tf
from natsort import ns, natsorted
import pandas as pd

from seg_models import build_model

from tensorflow.keras import mixed_precision
# Enable mixed precision globally for the training-time speedup it provides.
mixed_precision.set_global_policy('mixed_float16')

# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front. The setting has to be identical on every visible GPU, so apply it to
# each device rather than only the first one.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# Parameters ----------------------------------------------------------
batch_size = 32                        # samples per training/evaluation batch
w, h = 256, 256                        # spatial size of the model input
epochs = 10                            # training epochs per model
model_save_dir = './trained_models'    # where model checkpoints are written
# ---------------------------------------------------------------------

## Dataset pipeline

The following dataset pipeline uses an augmented dataset generated from manually annotated images; all source images were annotated using the VGG Image Annotator (VIA). Here, we simply use TensorFlow's dataset pipeline to process this dataset out-of-core in batches, which overcomes the memory limitations of our system. Note that images are rescaled by default in the model, so they can be supplied in uint8 format, though masks must be normalised to the 0 - 1 range.

In [None]:
# Data prep ------------------------------------------------------
def read_img(im, an, target_height=256, target_width=256):
    """Load one image/mask pair from PNG files and pad-resize both.

    Args:
        im: Path (string tensor) of the single-channel input image.
        an: Path (string tensor) of the matching annotation mask.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        A tuple ``(img, ann)``. ``img`` is the bilinearly resized image with
        its original intensity range. ``ann`` is a float32 mask containing
        only 0.0 (background / padding) and 1.0 (foreground).
    """
    img = tf.io.read_file(im)
    img = tf.image.decode_png(img, channels=1)
    img.set_shape([None, None, 1])
    img = tf.image.resize_with_pad(img, target_height, target_width)

    ann = tf.io.read_file(an)
    ann = tf.image.decode_png(ann, channels=1)
    ann.set_shape([None, None, 1])
    # Nearest-neighbour keeps the mask's label values intact; bilinear
    # interpolation would blend foreground and background along the edges.
    ann = tf.image.resize_with_pad(
        ann,
        target_height,
        target_width,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    # Binarise so the mask is on a 0 - 1 scale regardless of whether the file
    # stores the foreground as 1 or as 255.
    ann = tf.cast(ann > 0, tf.float32)

    return img, ann

def dataset(img_files, annot_files, batch_size):
    """Build a batched tf.data pipeline of (image, mask) pairs.

    Args:
        img_files: Sequence of image file paths.
        annot_files: Sequence of mask file paths, aligned with ``img_files``.
        batch_size: Number of pairs per batch; a trailing incomplete batch is
            dropped.

    Returns:
        A ``tf.data.Dataset`` yielding batches produced by ``read_img``.
    """
    path_pairs = tf.data.Dataset.from_tensor_slices((img_files, annot_files))
    # Decode and resize the files in parallel.
    decoded_pairs = path_pairs.map(
        read_img,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return decoded_pairs.batch(batch_size, drop_remainder=True)

def _annotation_paths(image_paths, image_dir, annotation_dir):
    """Return the mask path for each image path by swapping the leading directory.

    A plain, single string replacement is used instead of a regular
    expression so that '.' is not treated as a wildcard and the result does
    not depend on which path separator glob returns.
    """
    return [p.replace(image_dir, annotation_dir, 1) for p in image_paths]


# Training files are shuffled once before the pipeline is built.
train_images = glob.glob('./train_images/*.png')
random.shuffle(train_images)
train_annotations = _annotation_paths(
    train_images, './train_images', './train_annotations')

val_images = glob.glob('./val_images/*.png')
val_annotations = _annotation_paths(
    val_images, './val_images', './val_annotations')

test_images = glob.glob('./test_images/*.png')
test_annotations = _annotation_paths(
    test_images, './test_images', './test_annotations')

train_data = dataset(train_images, train_annotations, batch_size)
val_data = dataset(val_images, val_annotations, batch_size)
test_data = dataset(test_images, test_annotations, batch_size)

# Display a single example: the first pair of the first training batch.
for im, ann in train_data.take(1):
    i, a = im[0], ann[0]
    # Print the value range (max - min) of the image and of the mask as a
    # quick sanity check. np.ptp() is used because the ndarray.ptp() method
    # was removed in NumPy 2.0.
    print(np.ptp(i.numpy()), np.ptp(a.numpy()))
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(i, cmap='gray')
    ax2.imshow(a, cmap='gray')
    ax1.set_title('Input Image')
    ax2.set_title('Ground Truth Mask')
    plt.show()


## Training and evaluation

This is the main training loop for constructing, training and computing summary metrics for each model. For each model we use the Adam optimizer as well as Binary Cross Entropy for loss given that this is a binary segmentation problem.

In [None]:
# Model training -------------------------------------------------
# One entry is appended to every list per trained model; the dict is turned
# into a summary table at the end.
model_stats = dict(
    name=[],
    loss=[],
    val_loss=[],
    test_loss=[],
    binary_io_u=[],
    val_binary_io_u=[],
    test_binary_io_u=[],
    params=[],
    total_train_time=[],
    per_epoch_time=[],
    per_step_time=[]
)

model_defs = ['UNet', 'UNet2plus', 'UNet3plus', 'SegNet', 'FCN', 'PSPNet', 'DeepLabV3', 'HRNetV2']

# Shared 2x2 figure: every model adds one curve to each panel.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
ax1.set_title('Training loss')
ax2.set_title('Validation loss')
ax3.set_title('Training Binary IoU')
ax4.set_title('Validation Binary IoU')

for model_name in model_defs:
    print('---------------------------------')
    print(f'Training: {model_name}')
    print('---------------------------------')

    MODEL = build_model(input_shape=(w, h, 1), n_classes=1, model=model_name)

    MODEL.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.0001),
        loss=keras.losses.BinaryCrossentropy(from_logits=False),
        # Name the metric explicitly so the history keys read below do not
        # depend on Keras' automatic metric naming.
        metrics=[
            keras.metrics.BinaryIoU(
                target_class_ids=[0, 1],
                threshold=0.8,
                name='binary_io_u',
            )
        ],
    )

    # Single source of truth for the checkpoint location, used both for
    # saving during training and for reloading afterwards.
    checkpoint_path = f'{model_save_dir}/{model_name}_lymnaea_binary.h5'
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            save_best_only=True
        )
    ]

    # Train the model, doing validation at the end of each epoch.
    # The data is already batched by the tf.data pipeline, so batch_size is
    # not passed to fit().
    start = time.time()
    history = MODEL.fit(
        train_data,
        epochs=epochs,
        callbacks=callbacks,
        validation_data=val_data)
    end = time.time()

    # Evaluate the best checkpoint rather than the final-epoch weights.
    MODEL.load_weights(checkpoint_path)

    test_loss, test_iou = MODEL.evaluate(test_data)
    n_params = MODEL.count_params()
    print(n_params)

    hist = history.history
    train_time = end - start  # wall-clock time of fit(), validation included
    epoch_time = train_time / epochs

    model_stats['name'].append(model_name)
    # Best value over all epochs: lowest for losses, highest for IoU.
    model_stats['loss'].append(min(hist['loss']))
    model_stats['val_loss'].append(min(hist['val_loss']))
    model_stats['test_loss'].append(test_loss)
    model_stats['binary_io_u'].append(max(hist['binary_io_u']))
    model_stats['val_binary_io_u'].append(max(hist['val_binary_io_u']))
    model_stats['test_binary_io_u'].append(test_iou)
    model_stats['params'].append(n_params)
    model_stats['total_train_time'].append(train_time)
    model_stats['per_epoch_time'].append(epoch_time)
    model_stats['per_step_time'].append(epoch_time / len(train_data))

    ax1.plot(hist['loss'], label=model_name)
    ax2.plot(hist['val_loss'], label=model_name)
    ax3.plot(hist['binary_io_u'], label=model_name)
    ax4.plot(hist['val_binary_io_u'], label=model_name)

    # Release the model and reset Keras' global state before the next model
    # is built.
    del MODEL
    keras.backend.clear_session()

plt.legend(loc='lower right')
plt.show()

# Summary statistics
df = pd.DataFrame(data=model_stats)
df.to_csv('./model_stats_lymnaea.csv')