In [None]:
from __future__ import division, print_function
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import tensorflow as tf
import glob
import os
import random
from skimage import io
# NOTE: tf.executing_eagerly() only reports whether eager execution is active;
# it does not enable it, and the returned value is discarded here.
tf.executing_eagerly()

# Seed the random number generators used by this notebook so that
# "Restart & Run All" gives repeatable results: numpy, and the stdlib `random`
# module that picks the sample shown in the receptive-field plot.
np.random.seed(300)
random.seed(300)
plt.rcParams['image.cmap'] = 'gist_earth'

In [None]:
import unet

%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

In [None]:
def running_mean(x, n):
    """Return the running (moving) average of ``x`` over a window of ``n`` values.

    Output element ``i`` is the mean of ``x[i:i + n]``, so the result has
    ``len(x) - n + 1`` elements and is empty when ``x`` holds fewer than ``n``
    values.

    Parameters
    ----------
    x : array-like
        Sequence of numbers (flattened by ``np.insert`` if not 1-D).
    n : int
        Window length; must be at least 1.

    Returns
    -------
    numpy.ndarray
        The windowed means as floats.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1. Without this check a negative ``n`` would
        slice the cumulative sum from the wrong end and silently return
        meaningless numbers.
    """
    if n < 1:
        raise ValueError("window size n must be >= 1, got {}".format(n))
    # Prepend a 0 so that cumsum[i + n] - cumsum[i] is the sum of x[i:i + n].
    cumsum = np.cumsum(np.insert(x, 0, 0))
    return (cumsum[n:] - cumsum[:-n]) / float(n)

In [None]:
def plot_history(history, window=9):
    """Plot smoothed training/validation loss and binary accuracy side by side.

    Parameters
    ----------
    history : object with a ``history`` dict
        Must provide the keys ``'loss'``, ``'val_loss'``, ``'binary_accuracy'``
        and ``'val_binary_accuracy'``, each mapping to a per-epoch list
        (e.g. the object returned by ``Model.fit``).
    window : int, optional
        Number of epochs averaged by ``running_mean`` (default 9). It is clamped
        to the number of recorded epochs so that short runs still produce a
        curve instead of an empty plot.
    """
    records = history.history
    # running_mean returns an empty array when the window exceeds the data.
    window = max(1, min(window, len(records['loss'])))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    panels = (
        (ax1, 'loss', 'val_loss', 'train and validation loss', 'upper right'),
        (ax2, 'binary_accuracy', 'val_binary_accuracy',
         'train and validation binary accuracy', 'lower right'),
    )
    for axis, train_key, val_key, title, legend_loc in panels:
        train_curve = running_mean(records[train_key], window)
        val_curve = running_mean(records[val_key], window)
        axis.plot(range(len(train_curve)), train_curve, label=train_key)
        axis.plot(range(len(val_curve)), val_curve, label=val_key)
        axis.set_title(title)
        axis.set_xlabel('epoch (running mean, window={})'.format(window))
        axis.legend(loc=legend_loc)

    plt.show()

In [None]:
def show_predictions(raw, gt, pred):
    """Show input, ground truth and prediction in one row per sample.

    All three arrays are indexed as ``[sample, 0]`` to obtain the image that is
    drawn, i.e. only the first channel of every sample is displayed.

    For each sample the prediction panel shows the binary mask
    ``pred >= 0.9``. If a sample's maximum prediction stays below 0.9 the mask
    would be empty, so the raw prediction map is shown instead and a notice is
    printed.

    Parameters
    ----------
    raw, gt, pred : numpy.ndarray
        Input images, ground-truth masks and network outputs; ``pred`` decides
        how many rows are drawn.
    """
    thresh = 0.9
    # Highest predicted value per sample (channel 0).
    max_values = np.max(pred[:, 0], axis=(1, 2))
    if np.any(max_values < thresh):
        print("Heads up: If prediction is below {} then the prediction map is shown.".format(thresh))
        print("Max predictions: {}".format(max_values))

    num_samples = pred.shape[0]
    # squeeze=False keeps `ax` two-dimensional even for a single sample;
    # otherwise ax[i, 0] fails when num_samples == 1.
    fig, ax = plt.subplots(num_samples, 3, sharex=True, sharey=True,
                           figsize=(12, num_samples * 4), squeeze=False)
    for i in range(num_samples):
        ax[i, 0].imshow(raw[i, 0], aspect="auto")
        ax[i, 1].imshow(gt[i, 0], aspect="auto")
        # Fall back to the raw probability map when nothing reaches the threshold.
        if max_values[i] < thresh:
            ax[i, 2].imshow(pred[i, 0], aspect="auto")
        else:
            ax[i, 2].imshow(pred[i, 0] >= thresh, aspect="auto")

    ax[0, 0].set_title("Input")
    ax[0, 1].set_title("Ground truth")
    ax[0, 2].set_title("Prediction")
    fig.tight_layout()

## (1) Load and visualize our toy data examples:

In [None]:
# load tif images and reformat the way keras requires it
def load_dataset(in_folder):
    """Load all ``raw_*.tif`` images of a folder together with their labels.

    For every ``raw_<name>.tif`` the matching label image is read from
    ``gt_<name>.tif`` in the same folder.

    Parameters
    ----------
    in_folder : str
        Directory containing the ``raw_*.tif`` / ``gt_*.tif`` pairs.

    Returns
    -------
    x, y : numpy.ndarray
        Stacked raw and ground-truth images with a new axis of length 1
        inserted at position 1 (after the sample axis).
    """
    x = []
    y = []
    # glob returns files in arbitrary order; sort for a reproducible sample order.
    raw_files = sorted(glob.glob(os.path.join(in_folder, 'raw_*.tif')))
    for raw_file in raw_files:
        # Only rewrite the file name: replacing 'raw' in the whole path would
        # also change directory names that happen to contain 'raw'.
        folder, raw_name = os.path.split(raw_file)
        gt_file = os.path.join(folder, raw_name.replace('raw', 'gt', 1))
        x.append(io.imread(raw_file))
        y.append(io.imread(gt_file))
    x = np.array(x)[:, np.newaxis]
    y = np.array(y)[:, np.newaxis]
    return x, y

In [None]:
# read the three splits (train / val / test), each as an (images, labels) pair
(x_train, y_train), (x_val, y_val), (x_test, y_test) = (
    load_dataset('example_toy_data/train'),
    load_dataset('example_toy_data/val'),
    load_dataset('example_toy_data/test'),
)

In [None]:
# preview the first training pairs: raw image on the left, ground truth on the right
num_samples = 3
fig, ax = plt.subplots(nrows=num_samples, ncols=2, sharey=True, figsize=(8, 4 * num_samples))
for i, (raw_ax, gt_ax) in enumerate(ax):
    raw_ax.imshow(x_train[i, 0], aspect="auto")
    gt_ax.imshow(y_train[i, 0], aspect="auto")
ax[0, 0].set_title("Input")
ax[0, 1].set_title("Ground truth")
fig.tight_layout()

## (2) Create and train our model

In [None]:
# symbolic network input of shape (1, 512, 512)
net_input = tf.keras.Input(name='img', shape=(1, 512, 512))

# sigmoid activation layer that is handed to the U-Net builder
activation = tf.keras.layers.Activation(activation="sigmoid")

# build the U-Net graph; positional parameters: input, # output channels, unet depth, # fmaps
net_output, receptive_field = unet.unet(net_input, 1, 2, 32, activation=activation)

# wrap input and output tensors into a Keras model
net = tf.keras.Model(inputs=net_input, outputs=net_output, name='unet')

# list the layers and report the receptive field returned by the builder
net.summary()
print("Receptive field: ", receptive_field)

### Receptive Field of View

The number of convolutions and the depth of the U-Net are the major factors that determine the 
receptive field of the network. The term is borrowed from biology, where it describes the "portion of sensory space that can elicit neuronal responses when stimulated" (Wikipedia). For a network, the receptive field is the diameter of the input patch, centered at an output pixel's position, that this output pixel depends on.
Based on this patch alone, the network has to decide on the prediction for the respective pixel.
A larger receptive field gives the network more context, but it also increases the computation time significantly.

The following code snippet visualizes the field of view of the center pixel for networks with varying depth:

In [None]:
# pick a random training sample and one of the images stored in it
sample_idx = random.randrange(len(x_train))
out_channels = 1
images = x_train[sample_idx]
image = images[random.randrange(len(images))]

net_input_t = tf.keras.Input(shape=(1, 512, 512), name='img')
net_t = net_input_t

# Build a U-Net for each depth only to read the receptive field it reports
# (second return value); the network outputs themselves are not needed.
depths = range(1, 6)
fovs = [unet.unet(net_t, out_channels, depth=depth, num_fmaps=32)[1] for depth in depths]

fig = plt.figure(figsize=(8, 8))
colors = ["yellow", "red", "green", "blue", "magenta"]
plt.imshow(np.squeeze(image), cmap='gray')

# Centre the squares on the image: rows give the y extent, columns the x extent.
height, width = image.shape[0], image.shape[1]
for depth, fov_t, color in zip(depths, fovs, colors):
    print("Field of view at depth {}: {:3d} (color: {})".format(depth, fov_t, color))
    xmin = width / 2 - fov_t / 2
    xmax = width / 2 + fov_t / 2
    ymin = height / 2 - fov_t / 2
    ymax = height / 2 + fov_t / 2
    # draw the four edges of the square covered by the receptive field
    plt.hlines(ymin, xmin, xmax, color=color, lw=3)
    plt.hlines(ymax, xmin, xmax, color=color, lw=3)
    plt.vlines(xmin, ymin, ymax, color=color, lw=3)
    plt.vlines(xmax, ymin, ymax, color=color, lw=3)
plt.show()

In [None]:
# specify the training configuration (optimizer, loss, metrics)
# Adam optimizer with binary cross-entropy loss. Binary accuracy (threshold 0.5)
# is tracked as an additional metric; plot_history later reads it from the fit
# history under 'binary_accuracy' / 'val_binary_accuracy'.
net.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)]
)

In [None]:
# train the model, takes ~1:20min
# 60 epochs with mini-batches of 4 images; (x_val, y_val) is passed as validation
# data, and the returned `history` object keeps the per-epoch metrics for plotting.
history = net.fit(x=x_train, y=y_train, batch_size=4, epochs=60, validation_data=(x_val, y_val))
print('Finished Training')

In [None]:
# plot loss and accuracy of the first training run
# (plot_history smooths the curves with a running mean before drawing them)
plot_history(history)

## (3) Test and evaluate our model

In [None]:
# evaluate our model performance on the test set
# `results` stores whatever evaluate() returns (presumably the loss followed by
# the binary accuracy configured in compile() — TODO confirm); it is not used further.
results = net.evaluate(x=x_test, y=y_test)

In [None]:
# predict the test set: run the trained network on every test image
predictions = net.predict(x=x_test)

# plot predicted results next to the inputs and the ground truth, one row per test sample
show_predictions(x_test, y_test, predictions)

### A1: Continue training for more epochs

In [None]:
# continue training, takes ~3min
# heads up: the "net" variable still carries all the information from the previous training
# `epochs` is the index of the final epoch, not the number of additional epochs:
# with initial_epoch=60 and epochs=160 this call trains for another 100 epochs.
history_continued = net.fit(x=x_train, y=y_train, batch_size=4, epochs=160, validation_data=(x_val, y_val), 
                            initial_epoch=60)
print('Finished Training')

In [None]:
# append both histories
# Guarded so that re-running this cell does not append the continued run a
# second time: History.epoch lists the epoch indices a run covered, and after
# the merge the last index of `history` is no longer smaller than the first
# index of `history_continued`.
if history_continued.epoch and (not history.epoch or history.epoch[-1] < history_continued.epoch[0]):
    for k in history.history.keys():
        history.history[k] = history.history[k] + history_continued.history[k]
    history.epoch = history.epoch + history_continued.epoch

In [None]:
# plot loss and accuracy over both training runs
# (`history` now also contains the metrics of the continued training)
plot_history(history)

In [None]:
# evaluate and predict test set
# same checks as after the first run, now with the longer-trained weights;
# `results` and `predictions` from the earlier cells are overwritten
results = net.evaluate(x=x_test, y=y_test)

predictions = net.predict(x=x_test)
show_predictions(x_test, y_test, predictions)

#### The training of the networks depends on many hyperparameters, such as:
- network architecture: #layers, #fmaps
- batch size, learning rate
- number and distribution of the training samples

#### You can play with these settings and see how they influence the learning curve.

![](example_learning_curves/lc_all.png)

### A2: Use early stopping to avoid overfitting

In [None]:
# early stopping is one of the Keras callbacks that can be applied during the training procedure
from tensorflow.keras.callbacks import EarlyStopping

In [None]:
# build and train a second model with early stopping; the whole cell takes ~3min
# symbolic network input of shape (1, 512, 512)
net_input = tf.keras.Input(name='img', shape=(1, 512, 512))

# sigmoid activation layer that is handed to the U-Net builder
activation = tf.keras.layers.Activation(activation="sigmoid")

# build the U-Net graph; positional parameters: input, # output channels, unet depth, # fmaps
net_output, receptive_field = unet.unet(net_input, 1, 2, 32, activation=activation)

# a new model object, kept separate from the previously trained `net`
net_w_ea = tf.keras.Model(inputs=net_input, outputs=net_output, name='unet')

# same training configuration as before (optimizer, loss, metrics)
net_w_ea.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)],
)

# stop once val_loss has not improved for 30 epochs and restore the best weights
es = EarlyStopping(
    monitor='val_loss',
    mode='min',
    verbose=1,
    patience=30,
    restore_best_weights=True,
)
history_w_ea = net_w_ea.fit(
    x=x_train,
    y=y_train,
    batch_size=4,
    epochs=500,
    validation_data=(x_val, y_val),
    callbacks=[es],
)
print('Finished Training')

In [None]:
# plot loss and accuracy of the early-stopping run
plot_history(history_w_ea)

In [None]:
# evaluate and predict test set with the early-stopping model
results_w_ea = net_w_ea.evaluate(x=x_test, y=y_test)

# predictions are shown next to the inputs and the ground truth, one row per test sample
predictions_w_ea = net_w_ea.predict(x=x_test)
show_predictions(x_test, y_test, predictions_w_ea)

### A3: Use a data generator to avoid overfitting

In [None]:
# As we have simulated data, we can use unlimited number of training examples
# image generator copied from https://github.com/jakeret/tf_unet
import image_gen

In [None]:
# define image shape
nx = 512
ny = 512

# create a wrapper generator which can be used in keras
def train_generator(batch_size):
    """Yield an endless stream of freshly simulated training batches.

    Parameters
    ----------
    batch_size : int
        Number of samples requested from the data provider per batch.

    Yields
    ------
    (data, labels)
        The pair returned by calling the ``image_gen.GrayScaleDataProvider``
        with ``batch_size``.
    """
    # init image generator with the following parameters:
    # nx, ny, cnt = 10, r_min = 5, r_max = 50, border = 92, sigma = 20, limit_num_samples = -1, binary = True
    data_generator = image_gen.GrayScaleDataProvider(nx, ny, cnt=20, r_min=10, r_max=25, binary=True)
    # The loop never ends; the consumer decides how many batches to draw.
    while True:
        data, labels = data_generator(batch_size)
        yield data, labels

In [None]:
# build and train a third model on generated data; the whole cell takes ~3min
# define input shape (nx, ny are the image dimensions used by the generator)
net_input = tf.keras.Input(shape=(1, nx, ny), name='img')

# define activation function
activation = tf.keras.layers.Activation("sigmoid")

# create unet with parameters: input, # output channel, unet depth, # fmaps
net_output, receptive_field = unet.unet(net_input, 1, 2, 32, activation=activation)

# instantiate the model
net_w_gen = tf.keras.Model(net_input, net_output, name='unet')

# same training configuration as for the other models (optimizer, loss, metrics)
net_w_gen.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)]
)
# stop once val_loss has not improved for 30 epochs and restore the best weights
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=30, restore_best_weights=True)

# train the model by using the generator
# Model.fit accepts Python generators directly; fit_generator is deprecated.
# The generator never terminates, so steps_per_epoch defines how many batches
# form one epoch (here 4 batches of 4 generated images).
history_w_gen = net_w_gen.fit(
    x=train_generator(4),
    steps_per_epoch=4,
    epochs=140,
    validation_data=(x_val, y_val),
    callbacks=[es]
)
print('Finished Training')

In [None]:
# plot loss and accuracy of the generator-trained model
plot_history(history_w_gen)

In [None]:
# evaluate and predict test set with the generator-trained model
results_w_gen = net_w_gen.evaluate(x=x_test, y=y_test)

# predictions are shown next to the inputs and the ground truth, one row per test sample
predictions_w_gen = net_w_gen.predict(x=x_test)
show_predictions(x_test, y_test, predictions_w_gen)