## 01 UNet Training

The UNet model is used to distinguish foreground from background, i.e. DAPI-stained nuclei from the surrounding cytosol and intercellular space. For the most accurate predictions, training should ideally use only images that resemble those that will be predicted later. Therefore, the Data Science Bowl dataset is disregarded and only in-house labeled nuclei are used.
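Framed as pixel-wise classification, every pixel of a mask is assigned a class. The following is a minimal sketch, assuming a two-class layout (channel 0 background, channel 1 nucleus) and masks in which 0 marks background. The hypothetical helper `to_one_hot_foreground` turns such a label mask into the one-hot target that a categorical cross-entropy loss expects. The class layout actually produced by `utils.data_provider` is not shown here and may differ, for example by adding a boundary class.

```python
import numpy as np


def to_one_hot_foreground(label_mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map a label mask to a 2-class one-hot target.

    Channel 0 = background (label == 0), channel 1 = nucleus (label > 0).
    """
    foreground = label_mask > 0
    # Stack along the last axis so the result has shape (H, W, 2),
    # matching the channels-last layout Keras uses by default.
    return np.stack([~foreground, foreground], axis=-1).astype(np.float32)


# Usage: a 2x3 mask containing two nuclei (labels 1 and 2)
mask = np.array([[0, 1, 1],
                 [0, 0, 2]])
target = to_one_hot_foreground(mask)
print(target.shape)  # (2, 3, 2)
```

Both nuclei end up in the same foreground channel, so the instance identity is lost at this stage. That is acceptable for a foreground/background UNet, but it is the reason the unique instance labels have to be kept separately for StarDist.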

In [None]:
import glob
import skimage.io  # importing the submodule explicitly makes skimage.io.imread available
import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import utils.data_provider
import utils.dirtools
import utils.evaluation
import utils.model_builder
import utils.metrics
import utils.objectives

### Image import

All images are assumed to be grayscale TIFF (`.tif`) files. Because detection will only be run on 'normal' DAPI-labeled nuclei, all images are also assumed to have a black background with white signal. Signal density may vary from image to image. As the same images are used for StarDist prediction later, each nucleus should be labeled uniquely, i.e. the masks must be instance label images rather than binary masks.
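The assumptions above can be checked per image/mask pair before training. A minimal sketch, assuming both arrays have already been read (e.g. with `skimage.io.imread`) and that 0 is the background label; `check_image_and_mask` is a hypothetical helper, not part of the `utils` package. A nucleus label that falls apart into several disconnected regions is a typical sign of a binary mask or of two nuclei sharing one ID.

```python
import numpy as np
import skimage.measure


def check_image_and_mask(image: np.ndarray, mask: np.ndarray) -> list:
    """Hypothetical helper: return a list of problems (empty list = pair looks OK)."""
    problems = []
    if image.ndim != 2:
        problems.append(f'image is not single-channel grayscale (shape {image.shape})')
    if image.shape[:2] != mask.shape:
        problems.append(f'image shape {image.shape} does not match mask shape {mask.shape}')
    labels = np.unique(mask)
    labels = labels[labels != 0]  # 0 is treated as background
    if labels.size == 0:
        problems.append('mask contains no labeled nuclei')
    for lab in labels:
        # Each nucleus should form exactly one connected region.
        _, n_regions = skimage.measure.label(mask == lab, return_num=True)
        if n_regions > 1:
            problems.append(f'label {lab} is split into {n_regions} regions')
    return problems


good = np.array([[1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [0, 0, 0, 2],
                 [0, 0, 2, 2]])
print(check_image_and_mask(np.zeros((4, 4)), good))                     # []
print(check_image_and_mask(np.zeros((4, 4)), (good > 0).astype(int)))   # ['label 1 is split into 2 regions']
```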

In [None]:
root = './data/train_val/'
img_size = 256

In [None]:
# Import paths
X = sorted(glob.glob(f'{root}images/*.tif'))
Y = sorted(glob.glob(f'{root}masks/*.tif'))
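Sorting pairs images and masks purely by position, so a single missing or extra file silently shifts every later pair. A minimal check, assuming (this is not stated in the notebook) that each mask has the same file name as its image; `check_pairs` is a hypothetical helper.

```python
import os


def check_pairs(x_paths, y_paths):
    """Hypothetical helper: confirm that images and masks pair up by file name."""
    if len(x_paths) != len(y_paths):
        raise ValueError(f'{len(x_paths)} images but {len(y_paths)} masks')
    for x_path, y_path in zip(x_paths, y_paths):
        if os.path.basename(x_path) != os.path.basename(y_path):
            raise ValueError(f'mismatched pair: {x_path} vs {y_path}')


# Usage with made-up paths; in the notebook this would be check_pairs(X, Y)
check_pairs(['./data/train_val/images/a.tif'], ['./data/train_val/masks/a.tif'])
```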

In [None]:
# Train / valid split
x_train, x_valid, y_train, y_valid = utils.dirtools.train_valid_split(x_list=X, y_list=Y, valid_split=0.2)
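The implementation of `utils.dirtools.train_valid_split` is not shown here. As a rough illustration only, a pair-preserving split could look like the hypothetical `train_valid_split_sketch` below: it shuffles indices once and applies the same order to images and masks, so every mask stays with its image. Whether the real utility shuffles, and with which seed, is an assumption.

```python
import numpy as np


def train_valid_split_sketch(x_list, y_list, valid_split=0.2, seed=0):
    """Hypothetical stand-in for utils.dirtools.train_valid_split.

    Returns (x_train, x_valid, y_train, y_valid), the same order the notebook unpacks.
    """
    if len(x_list) != len(y_list):
        raise ValueError('x_list and y_list must have the same length')
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x_list))  # one shuffle shared by images and masks
    n_valid = int(round(len(x_list) * valid_split))
    valid_idx, train_idx = order[:n_valid], order[n_valid:]
    x_train = [x_list[i] for i in train_idx]
    x_valid = [x_list[i] for i in valid_idx]
    y_train = [y_list[i] for i in train_idx]
    y_valid = [y_list[i] for i in valid_idx]
    return x_train, x_valid, y_train, y_valid


# Usage: 10 pairs with valid_split=0.2 gives 8 training and 2 validation pairs
imgs = [f'img_{i}.tif' for i in range(10)]
msks = [f'msk_{i}.tif' for i in range(10)]
x_tr, x_va, y_tr, y_va = train_valid_split_sketch(imgs, msks, valid_split=0.2)
print(len(x_tr), len(x_va))  # 8 2
```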

In [None]:
# Sanity check
ix = np.random.randint(0, len(X))  # upper bound is exclusive, so every image can be drawn

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].imshow(skimage.io.imread(X[ix]))
ax[0].set_title(f'Original Image – #{ix}')
ax[1].imshow(skimage.io.imread(Y[ix]))
ax[1].set_title('Ground Truth')
plt.show()

### Build model

In [None]:
# Build model
tf.keras.backend.clear_session()
model = utils.model_builder.standard_unet()
model.summary()

# Compile model
# loss = utils.objectives.weighted_crossentropy
loss = tf.keras.losses.categorical_crossentropy
metrics = [tf.keras.metrics.categorical_accuracy]
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)  # `lr` is deprecated in TF 2.x

model.compile(loss=loss, metrics=metrics, optimizer=optimizer)

# Callbacks
tf.io.gfile.makedirs('./models')  # checkpoint and CSV callbacks fail if the directory is missing
model_name = f"./models/{datetime.date.today().strftime('%Y%m%d')}_UNet"
callbacks = [utils.metrics.PlotLosses(),
             tf.keras.callbacks.ModelCheckpoint(f'{model_name}.h5', save_best_only=True),
             tf.keras.callbacks.CSVLogger(filename=f'{model_name}.csv'),
             tf.keras.callbacks.TensorBoard(model_name)]
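To make the chosen loss and metric concrete, here is a NumPy sketch of what they compute on one-hot targets and softmax outputs. It mirrors `categorical_crossentropy` and `categorical_accuracy` only up to implementation details (Keras also clips and renormalizes predictions internally), and the function names are hypothetical.

```python
import numpy as np


def pixelwise_categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over pixels of -sum_c y_true[..., c] * log(y_pred[..., c])."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))


def pixelwise_categorical_accuracy(y_true, y_pred):
    """Fraction of pixels whose arg-max class matches the target class."""
    return float(np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1)))


y_true = np.array([[1.0, 0.0], [0.0, 1.0]])  # two pixels, one-hot targets
y_pred = np.array([[0.8, 0.2], [0.4, 0.6]])  # softmax outputs
print(pixelwise_categorical_crossentropy(y_true, y_pred))  # about 0.367
print(pixelwise_categorical_accuracy(y_true, y_pred))      # 1.0
```

With mostly background pixels, accuracy can look high even when nuclei are missed, which is one motivation for the weighted cross-entropy left commented out in the cell above.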

### Training

Training progress can be monitored in TensorBoard: from the `./models` directory, run `tensorboard --logdir=.` and open [localhost:6006](http://localhost:6006) in a browser.

In [None]:
# Build generators
train_gen = utils.data_provider.random_sample_generator(
    x_list=x_train,
    y_list=y_train,
    batch_size=16,
    bit_depth=16,
    dim_size=img_size)

val_gen = utils.data_provider.single_data_from_images(
    x_valid,
    y_valid,
    batch_size=16,
    bit_depth=16,
    dim_size=img_size)
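The `utils.data_provider` generators are project-specific and their internals are not shown. The sketch below only illustrates what a random-sample generator with these parameters plausibly does, under stated assumptions: it draws random `dim_size` × `dim_size` crops at the same position from image and mask, scales images to [0, 1] by the maximum value representable at `bit_depth` bits, and loops forever as Keras generators do. The name `random_crop_batches` is hypothetical, and images are assumed to be at least `dim_size` pixels on each side.

```python
import numpy as np


def random_crop_batches(images, masks, batch_size=16, bit_depth=16, dim_size=256, seed=None):
    """Hypothetical sketch of a random-sample generator (not utils.data_provider)."""
    rng = np.random.default_rng(seed)
    max_value = 2 ** bit_depth - 1  # e.g. 65535 for 16-bit images
    while True:  # Keras-style generators yield batches indefinitely
        x_batch, y_batch = [], []
        for _ in range(batch_size):
            i = rng.integers(len(images))
            img, msk = images[i], masks[i]
            # Same top-left corner for image and mask keeps them aligned.
            r = rng.integers(img.shape[0] - dim_size + 1)
            c = rng.integers(img.shape[1] - dim_size + 1)
            x_batch.append(img[r:r + dim_size, c:c + dim_size] / max_value)
            y_batch.append(msk[r:r + dim_size, c:c + dim_size])
        # Add a channel axis to the grayscale images: (batch, H, W, 1)
        yield np.stack(x_batch)[..., np.newaxis], np.stack(y_batch)


# Usage with a synthetic 16-bit image and a 2-channel one-hot mask
gen = random_crop_batches([np.full((300, 320), 65535, np.uint16)],
                          [np.ones((300, 320, 2))], batch_size=4, seed=0)
x, y = next(gen)
print(x.shape, y.shape)  # (4, 256, 256, 1) (4, 256, 256, 2)
```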

In [None]:
# Training
# `fit_generator` is deprecated in TF 2.x; `fit` accepts generators directly
statistics = model.fit(train_gen,
                       steps_per_epoch=20,
                       epochs=5,
                       validation_data=val_gen,
                       validation_steps=20,
                       callbacks=callbacks,
                       verbose=2)

model.save_weights(f'{model_name}_final.h5')
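`ModelCheckpoint` monitors `val_loss` by default, so with `save_best_only=True` the `.h5` file holds the model from the epoch with the lowest validation loss, whereas `{model_name}_final.h5` holds the weights after the last epoch. A small sketch for finding that best epoch in the `CSVLogger` output; the helper name `best_epoch` is hypothetical, and it assumes the log has `epoch` and `val_loss` columns, with epochs counted from 0.

```python
import csv


def best_epoch(csv_path, monitor='val_loss'):
    """Return (epoch, value) for the row with the lowest `monitor` value."""
    with open(csv_path, newline='') as f:
        rows = list(csv.DictReader(f))
    best = min(rows, key=lambda row: float(row[monitor]))
    return int(best['epoch']), float(best[monitor])


# In the notebook: best_epoch(f'{model_name}.csv')
```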