# Train a U-Net for puncta segmentation

Imports:

In [None]:
import numpy as np
import h5py
import cv2
from os import walk, makedirs
from os.path import join, exists
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_addons as tfa
import segmentation_models as sm
from tqdm import tqdm
from segmentation_models.metrics import f1_score, iou_score
from skimage import filters
import glob
import re
import pickle
from datetime import datetime

In [None]:
# Make segmentation_models build its networks with tf.keras (the Keras bundled
# with TensorFlow) so they can be combined with the tf.keras layers used later.
sm.set_framework('tf.keras')

Set training parameters:

In [None]:
# Optimisation hyper-parameters.
batch_size = 10       # images per gradient step
learning_rate = 1e-4  # Adam step size
num_epochs = 200      # total number of epochs, also when resuming a run

# Run bookkeeping.
model_name = 'punctae_seg_model'  # folder that holds the per-epoch checkpoints
seed = 42                         # shared by the image and mask generators

## Loading images and masks
Load the pre-compiled image data and the corresponding puncta (dot) masks from the HDF5 file.

In [None]:
# Open the compiled dataset read-only. The handle is intentionally kept open:
# the X_*/y_* splits taken from it are h5py datasets that read from this file.
data = h5py.File('data/compiled_punctae_seg_data_final.h5', mode='r')

In [None]:
X_train = data['X_train']
X_val = data['X_val']
y_train = data['y_train']
y_val = data['y_val']

# Sanity-check shape and dtype of every split (images first, then masks).
for split in (X_train, X_val, y_train, y_val):
    print(split.shape, split.dtype)

## Resuming training
If checkpoints from a previous run exist, resume from the most recent one; otherwise, train from scratch.

In [None]:
from os.path import basename

# Checkpoints are written by the ModelCheckpoint callback of the training cell
# as 'model.{epoch:03d}-{val_my_dice:.4f}.h5'. Keras fills {epoch} with the
# 1-based number of the completed epoch. '\d+' (not '\d{3}') also matches
# numbers of 1000 and above, which '{epoch:03d}' prints with four digits.
checkpoint_epoch_re = re.compile(r'\.(\d+)-')

model_weights_path = None
start_from_epoch = 0
if not exists(model_name):
    print('Training from scratch.')
    makedirs(model_name)
else:
    # Map each checkpoint file to its epoch number. Only the file name is
    # parsed (never the directory part). .h5 files without an epoch number
    # (e.g. a manually exported model) are skipped instead of crashing.
    checkpoint_epochs = {}
    for path in glob.glob(join(model_name, '*.h5')):
        match = checkpoint_epoch_re.search(basename(path))
        if match is not None:
            checkpoint_epochs[path] = int(match.group(1))
    if not checkpoint_epochs:
        print('No models were stored. Training from scratch.')
    else:
        model_weights_path = max(checkpoint_epochs, key=checkpoint_epochs.get)
        # Keras' initial_epoch is 0-based, so the 1-based number of the last
        # completed epoch is exactly the index of the next epoch to run.
        start_from_epoch = checkpoint_epochs[model_weights_path]
        print('Starting from checkpoint %s. Epoch=%i (one-indexed)' % (model_weights_path, start_from_epoch))

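## Definition of the U-Net model
A trainable 3×3 convolution maps the 2-channel input images to the 3 channels expected by the ImageNet-pretrained ResNet34 encoder.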

In [None]:
# U-Net with an ImageNet-pretrained ResNet34 encoder and one sigmoid output
# channel (per-pixel foreground probability). The pretrained encoder is built
# for 3-channel input of the same spatial size as the data.
unet = sm.Unet(
    'resnet34',
    classes=1,
    activation='sigmoid',
    encoder_weights='imagenet',
    input_shape=(X_train.shape[1:-1]) + (3,),
    decoder_block_type='transpose'
)

# The images have 2 channels. A trainable 3x3 convolution maps them to the
# 3 channels the encoder expects. This is a learned mapping, not a channel copy.
image_input = tf.keras.Input(shape=(X_train.shape[1:-1]) + (2,), dtype=tf.float32, name='image_input')
image_rgb = tf.keras.layers.Conv2D(filters=3, kernel_size=(3,3), padding='same')(image_input)

posterior = unet(image_rgb)

model = tf.keras.Model(inputs=image_input, outputs=posterior)

# summary() prints the table itself and returns None. Wrapping it in print()
# only appended a stray "None" line.
model.summary()

## Compile the model with a Dice loss
Define a soft Dice loss for training and a thresholded Dice coefficient as the metric monitored on the training and validation sets.

In [None]:
def my_dice(y_true, y_pred, true_threshold=0.01, pred_threshold=0.01):
    """Hard (thresholded) Dice coefficient, used as the monitoring metric.

    Both tensors are binarised and flattened over the whole batch, so the
    result is a single batch-level Dice score rather than a per-image mean.
    A +1 smoothing term makes an empty prediction of an empty mask score 1.

    Args:
        y_true: ground-truth mask tensor; any numeric dtype is accepted.
        y_pred: predicted foreground probabilities (the model ends in a
            sigmoid).
        true_threshold: values of ``y_true`` above this count as foreground.
        pred_threshold: values of ``y_pred`` above this count as foreground.
            NOTE(review): 0.01 is very permissive for sigmoid outputs (0.5 is
            the usual cut-off). The default is kept so reported metrics and
            checkpoint names stay comparable with earlier runs.

    Returns:
        Scalar float32 tensor in [0, 1].
    """
    # Cast to float before comparing so integer masks also work, then cast
    # the boolean result straight to float (no intermediate int tensor).
    y_true_f = K.flatten(tf.cast(tf.cast(y_true, tf.float32) > true_threshold, tf.float32))
    y_pred_f = K.flatten(tf.cast(tf.cast(y_pred, tf.float32) > pred_threshold, tf.float32))

    intersection = K.sum(y_true_f * y_pred_f)

    dice = (2.0 * intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1.0)
    return dice

def my_dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss, ``1 - Dice``, on the raw predicted probabilities.

    No thresholding is applied, so the loss is differentiable. Both tensors
    are flattened over the whole batch, giving one batch-level Dice rather
    than a per-image average.

    Args:
        y_true: ground-truth mask tensor. It is cast to ``y_pred``'s dtype,
            so integer masks are accepted too.
        y_pred: predicted foreground probabilities.
        smooth: small constant that keeps the ratio defined when both tensors
            are all zero (an all-zero prediction of an all-zero mask gives
            loss 0).

    Returns:
        Scalar loss tensor.
    """
    # Cast first: multiplying an integer mask by float predictions would fail.
    y_true_f = K.flatten(tf.cast(y_true, y_pred.dtype))
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    dice = (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return 1 - dice

In [None]:
# Adam on the soft Dice loss. The thresholded Dice is tracked as a metric and
# shows up as 'my_dice' / 'val_my_dice' in the history and checkpoint names.
adam = tf.keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=adam, loss=my_dice_loss, metrics=[my_dice])

Load the checkpoint weights (only when resuming):

In [None]:
# Only a resumed run has a checkpoint to restore. A fresh run keeps the
# ImageNet-initialised encoder and randomly initialised remaining layers.
resuming = model_weights_path is not None
if resuming:
    print('Loading weights %s' % (model_weights_path,))
    model.load_weights(model_weights_path)

## Create online augmentation generators for image-mask pairs

Augmentation for training:

In [None]:
def binarize_mask(x):
    """Map an augmented mask back to 0/1.

    Augmentation returns the mask as a float array whose values are no longer
    exactly 0/1, so it is re-thresholded with Otsu's method. An all-zero mask
    (nothing in view) is returned as zeros directly instead of being passed to
    threshold_otsu, which is not meaningful for a constant image.
    """
    if np.all(x == 0):
        return np.where(x > 0, 1, 0)
    return np.where(x > filters.threshold_otsu(x), 1, 0)


data_gen_args_img = dict(
    rotation_range=15,
    horizontal_flip=True,
    vertical_flip=True,
    shear_range=15,
    brightness_range=[0.8, 1.2],
    zoom_range=[0.8, 1.2],
    fill_mode='nearest',
    preprocessing_function=lambda x: x,
)

# NOTE: the image and mask generators stay aligned only because both use the
# same seed and draw exactly the same random numbers per sample. The mask
# settings must therefore mirror the image settings, including
# brightness_range, which is meaningless for a mask but still consumes a random
# draw. Deriving the dict guarantees this; only the post-processing differs.
data_gen_args_msk = dict(data_gen_args_img, preprocessing_function=binarize_mask)

image_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**data_gen_args_img)
mask_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**data_gen_args_msk)

# ImageDataGenerator.fit() is deliberately not called. It only computes the
# statistics used by featurewise_center, featurewise_std_normalization and
# zca_whitening, none of which are enabled. Here it would merely augment the
# whole training set once and throw the result away.

# Identical seed and shuffle settings keep the two streams in lockstep.
train_image_generator = image_datagen.flow(X_train, batch_size=batch_size, seed=seed, shuffle=True)
train_mask_generator = mask_datagen.flow(y_train, batch_size=batch_size, seed=seed, shuffle=True)

train_generator = zip(train_image_generator, train_mask_generator)

Create the same generators for the validation set, but without augmentation or shuffling:

In [None]:
# Validation pipeline: same generator setup as for training, but every random
# transform is disabled and shuffling is off. Each epoch therefore evaluates the
# same images in the same order.
# NOTE(review): brightness_range=[1.0, 1.0] is kept on purpose. With any
# brightness_range set, Keras appears to route each image through PIL and
# rescale it to 0-255, as happens for the training images. Setting it to None
# would feed differently scaled images. Confirm before changing.
data_gen_args_img_val = dict(
    rotation_range=0,
    horizontal_flip=False,
    vertical_flip=False,
    shear_range=0,
    brightness_range=[1.0, 1.0],
    zoom_range=[1.0, 1.0],
    # No geometric transform is applied, so fill_mode is never used. 'nearest'
    # simply matches the mask settings.
    fill_mode='nearest',
    preprocessing_function=lambda x: x,
)

# Masks mirror the image settings and reuse the training-mask binarisation.
data_gen_args_msk_val = dict(
    data_gen_args_img_val,
    preprocessing_function=data_gen_args_msk['preprocessing_function'],
)

val_image_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**data_gen_args_img_val)
val_mask_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**data_gen_args_msk_val)

# No .fit() call: no featurewise statistics are enabled, so it would be a no-op.

val_image_generator = val_image_datagen.flow(X_val, batch_size=batch_size, seed=seed, shuffle=False)
val_mask_generator = val_mask_datagen.flow(y_val, batch_size=batch_size, seed=seed, shuffle=False)

validation_generator = zip(val_image_generator, val_mask_generator)

## Train the model

In [None]:
# Whole batches per epoch, rounded up so the final partial batch is included.
# Keras expects integer step counts.
steps_per_epoch = int(np.ceil(len(X_train) / batch_size))
validation_steps = int(np.ceil(len(X_val) / batch_size))

# Both generators already yield batches, so batch_size is not passed to fit().
# Keras documents it as not to be specified for generator inputs.
history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=num_epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    callbacks=[
        # One checkpoint per epoch. The 1-based epoch number in the file name
        # is parsed back out by the resume cell.
        tf.keras.callbacks.ModelCheckpoint(filepath=join(model_name, 'model.{epoch:03d}-{val_my_dice:.4f}.h5'))
    ],
    # initial_epoch is 0-based: after N completed epochs, resume at index N.
    initial_epoch=start_from_epoch
)

Save the training history to disk:

In [None]:
# Persist the per-epoch metrics (e.g. 'loss', 'my_dice', 'val_my_dice') so the
# curves can be re-plotted without retraining. The context manager closes the
# file even if pickling fails.
# NOTE(review): a resumed run overwrites this file with only the epochs it ran.
training_history = history.history
with open("punctae_training_history.pkl", "wb") as history_file:
    pickle.dump(training_history, history_file)

Plot the training and validation Dice curves and save the figure as a PDF:

In [None]:
epochs = np.arange(len(training_history['loss']))

matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
%config InlineBackend.figure_format = 'retina'

fig = plt.figure(figsize=(15,8))
plt.plot(epochs, training_history['my_dice'], label='Training')
plt.plot(epochs, training_history['val_my_dice'], label='Validation')

plt.ylabel('DICE', fontsize=20)
plt.xlabel('Epoch', fontsize=20)
plt.rc('xtick', labelsize=15)
plt.rc('ytick', labelsize=15)
plt.xlim((0, 200))
plt.ylim((0, 1.0))
plt.grid()
plt.legend(loc=4, prop={'size': 20})
plt.savefig('convergence_punctae.pdf')