# Training a U-Net on the BAGLS dataset

Load the required components

```python
import numpy as np
import keras
from keras.optimizers import Adam
import tensorflow as tf
import random

# Splitting train and validation data
from sklearn.model_selection import train_test_split

# Loss and evaluation metric
from segmentation_models.losses import dice_loss
from segmentation_models.metrics import iou_score
```

Using TensorFlow backend.
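The notebook takes `dice_loss` and `iou_score` from `segmentation_models`. To illustrate what these two quantities measure for a binary mask, here is a NumPy sketch of the standard soft Dice loss and of the IoU (Jaccard index). It is not the library's implementation: the function names, the `smooth` constant and the 0.5 binarization threshold are assumptions, and `segmentation_models` has its own smoothing, class-weighting and per-image options that may differ by version (its IoU may, for instance, be computed on soft predictions).

```python
import numpy as np

def dice_loss_np(y_true, y_pred, smooth=1e-5):
    """Soft Dice loss 1 - 2|A∩B| / (|A| + |B|), using probabilities for B."""
    intersection = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    # smooth (hypothetical value) avoids division by zero for empty masks
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)

def iou_np(y_true, y_pred, threshold=0.5, smooth=1e-5):
    """Intersection over union |A∩B| / |A∪B| after binarizing the prediction."""
    pred = (y_pred > threshold).astype(np.float64)
    intersection = np.sum(y_true * pred)
    union = np.sum(y_true) + np.sum(pred) - intersection
    return (intersection + smooth) / (union + smooth)

# The ground truth covers 2 pixels; the thresholded prediction finds 1 of them
y_true = np.array([[1.0, 1.0], [0.0, 0.0]])
y_pred = np.array([[0.9, 0.2], [0.1, 0.0]])
print(round(dice_loss_np(y_true, y_pred), 3))  # 0.312
print(round(iou_np(y_true, y_pred), 3))        # 0.5
```

Because the Dice loss is computed on probabilities, it stays differentiable and can be minimized directly, while the thresholded IoU is better suited as an evaluation metric.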


## Data pre-processing and dynamic feeding

We use a [data generator](https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly) to provide images on the fly during training and to parallelize the data pre-processing (e.g. image augmentation).

```python
from DataGenerator import DataGenerator
```
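The `DataGenerator` module is not shown here. Following the pattern of the linked tutorial, such a generator implements `__len__`, `__getitem__` and `on_epoch_end` (the `keras.utils.Sequence` interface), which lets `fit_generator` fetch batches from several worker processes. The class below is a hypothetical stand-in, not the notebook's implementation: it is a plain Python class so that it runs without Keras, it takes a hypothetical `loader` callable that turns a file path into a 2D array, and its random horizontal flip is only a placeholder for whatever augmentation the real generator applies.

```python
import numpy as np

class SketchDataGenerator:
    """Hypothetical minimal generator following the keras.utils.Sequence interface.

    With Keras it would subclass keras.utils.Sequence; it is a plain class here
    so that it runs without Keras.
    """

    def __init__(self, images, masks, batch_size, augment=False, shuffle=False,
                 loader=None, seed=None):
        self.images = list(images)
        self.masks = list(masks)
        self.batch_size = batch_size
        self.augment = augment
        self.shuffle = shuffle
        self.loader = loader  # hypothetical: callable mapping a file path to a 2D array
        self.rng = np.random.default_rng(seed)
        self.indexes = np.arange(len(self.images))
        self.on_epoch_end()

    def __len__(self):
        # Number of complete batches per epoch; the incomplete last batch is dropped
        return len(self.images) // self.batch_size

    def __getitem__(self, idx):
        batch = self.indexes[idx * self.batch_size:(idx + 1) * self.batch_size]
        X, y = [], []
        for i in batch:
            img = self.loader(self.images[i])
            seg = self.loader(self.masks[i])
            if self.augment and self.rng.random() < 0.5:
                # Placeholder augmentation: flip image and mask together
                img, seg = img[:, ::-1], seg[:, ::-1]
            X.append(img)
            y.append(seg)
        # Add a channel axis: (batch, height, width, 1)
        return np.stack(X)[..., np.newaxis], np.stack(y)[..., np.newaxis]

    def on_epoch_end(self):
        # Reshuffle the sample order after every epoch (training data only)
        if self.shuffle:
            self.rng.shuffle(self.indexes)
```

Any geometric augmentation has to be applied identically to the image and its mask; otherwise the segmentation target no longer lines up with the input.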

## Custom callbacks

We use the following custom callbacks:

- Cyclic learning rate (between $10^{-3}$ and $10^{-6}$)
- Saving the model whenever the validation IoU exceeds the best value so far
- TQDM-based progress bar
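The notebook does not say which cyclic schedule `get_callbacks` uses. One common choice is the "triangular" cyclical learning rate policy described by Leslie Smith, which ramps linearly from a lower to an upper bound and back. The sketch below assumes that policy, the bounds $10^{-6}$ and $10^{-3}$ from the list above, and a hypothetical half-cycle length `step_size` measured in training iterations (batches).

```python
import math

def triangular_lr(iteration, base_lr=1e-6, max_lr=1e-3, step_size=2000):
    """Triangular cyclic learning rate (assumed policy; step_size is hypothetical)."""
    # Index of the current cycle; each cycle spans 2 * step_size iterations
    cycle = math.floor(1 + iteration / (2 * step_size))
    # Relative position in the cycle: 1 at the cycle edges, 0 at the peak
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)

for it in (0, 1000, 2000, 3000, 4000):
    print(it, triangular_lr(it))
```

The rate starts at `base_lr`, peaks at `max_lr` after `step_size` iterations and returns to `base_lr` after `2 * step_size`. In Keras, a per-batch schedule like this could be applied from a custom callback, e.g. by calling `keras.backend.set_value(self.model.optimizer.lr, ...)` in `on_batch_begin`; whether `get_callbacks` works this way is an assumption.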

```python
from Callbacks import get_callbacks
```
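The messages printed during training (e.g. "val_iou_score improved from -inf to 0.67093, saving model to Unet.h5") match the format of Keras' `ModelCheckpoint` callback, so `get_callbacks` probably configures something like `ModelCheckpoint(MODEL_PATH, monitor='val_iou_score', mode='max', save_best_only=True, verbose=1)`; this is an inference, not confirmed by the notebook. The plain-Python class below is a hypothetical stand-in that reproduces only the comparison logic: the model is saved whenever the monitored score beats the best score seen so far, not merely the previous epoch's score.

```python
import math

class BestScoreCheckpoint:
    """Hypothetical stand-in for ModelCheckpoint(save_best_only=True, mode='max')."""

    def __init__(self, monitor="val_iou_score"):
        self.monitor = monitor
        self.best = -math.inf  # nothing seen yet, so the first epoch always improves

    def on_epoch_end(self, epoch, logs):
        score = logs[self.monitor]
        if score > self.best:
            print(f"Epoch {epoch + 1:05d}: {self.monitor} improved "
                  f"from {self.best:.5f} to {score:.5f}")
            self.best = score
            return True  # the real callback would save the model here
        print(f"Epoch {epoch + 1:05d}: {self.monitor} did not improve from {self.best:.5f}")
        return False
```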

## Load neural network components and build U-Net

```python
from Unet import Unet
```

```python
model = Unet()
model.summary()
```

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None, None, 1 0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 6 576         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 6 256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 6 0           batch_normalization_1[0][0]      
...
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, None, None, 5 2048        conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_11 (Activation)      (None, None, None, 5 0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, None, None, 2 1179648     activation_11[0][0]              
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, None, None, 2 1024        conv2d_12[0][0]                  
__________________________________________________________________________________________________
activation_12 (Activation)      (None, None, None, 2 0           batch_normalization_12[0][0]     
...
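The parameter counts in the summary can be checked by hand. Assuming 3×3 kernels and no bias terms in the convolutions (an inference: the counts match exactly, and a bias is redundant when batch normalization follows), a `Conv2D` layer has $k \cdot k \cdot C_\text{in} \cdot C_\text{out}$ weights, and a `BatchNormalization` layer stores four values per channel (γ, β, moving mean, moving variance), of which only γ and β are trainable. Because the output-shape column is truncated, the channel counts below (1 → 64 and 512 → 256) are inferred from the batch-normalization parameter counts.

```python
def conv2d_params(kernel_size, in_channels, filters, use_bias=False):
    """Number of parameters of a square 2D convolution."""
    weights = kernel_size * kernel_size * in_channels * filters
    return weights + (filters if use_bias else 0)

def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance for each channel."""
    return 4 * channels

print(conv2d_params(3, 1, 64))     # conv2d_1: 576
print(batchnorm_params(64))        # batch_normalization_1: 256
print(batchnorm_params(512))       # batch_normalization_11: 2048
print(conv2d_params(3, 512, 256))  # conv2d_12: 1179648
print(batchnorm_params(256))       # batch_normalization_12: 1024
```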

Define the paths to the images and segmentation masks

```python
# Location of the training data and path to the saved model
TRAINING_PATH = "C:/BAGLS/training/"
MODEL_PATH = "Unet.h5"

# Number of training images; image i is stored as i.png, its mask as i_seg.png
N = 55750

train_imgs = [TRAINING_PATH + str(i) + ".png" for i in range(N)]
train_segs = [TRAINING_PATH + str(i) + "_seg.png" for i in range(N)]
```
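The cell above builds the paths from a hard-coded image count and the naming convention `<i>.png` / `<i>_seg.png`. As an alternative, a helper (the name `list_bagls_pairs` is hypothetical) could derive the pairs from the folder contents. The sketch below sorts the numeric file names as integers, so that `10.png` comes after `2.png`, and skips any image that has no mask.

```python
from pathlib import Path

def list_bagls_pairs(folder):
    """Return (image_paths, mask_paths) for files named <i>.png and <i>_seg.png."""
    folder = Path(folder)
    # Keep numeric stems only ("0", "1", ...); mask stems end in "_seg"
    stems = sorted((p.stem for p in folder.glob("*.png") if p.stem.isdigit()), key=int)
    images, masks = [], []
    for stem in stems:
        mask = folder / f"{stem}_seg.png"
        if mask.exists():  # skip images without a segmentation mask
            images.append(str(folder / f"{stem}.png"))
            masks.append(str(mask))
    return images, masks
```

Called as `train_imgs, train_segs = list_bagls_pairs(TRAINING_PATH)`, it would return the same ordering as the list comprehensions, provided that every mask is present.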

## Training

```python
# Fix random seeds (full reproducibility is still not guaranteed with
# multiprocessing workers and non-deterministic GPU operations)
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)
random.seed(SEED)

# Define training parameters
BATCH_SIZE = 16 # adjust for your graphics card
LEARNING_RATE = 1e-3 # initial rate; the cyclic schedule varies it between 1e-3 and 1e-6
EPOCHS = 25

# Hard split of training and validation data (5% for validation)
X, X_val, y, y_val = train_test_split(train_imgs,
                                      train_segs,
                                      test_size=0.05,
                                      random_state=SEED)

# Augment and shuffle the training data
train_gen = DataGenerator(X,
                          y,
                          BATCH_SIZE,
                          augment=True,
                          shuffle=True)

# Neither augment nor shuffle the validation data
val_gen = DataGenerator(X_val,
                        y_val,
                        BATCH_SIZE,
                        augment=False,
                        shuffle=False)

# Compile the model with the Adam optimizer (its learning rate is varied
# by the cyclic learning rate callback) and the Dice loss
model.compile(optimizer=Adam(lr=LEARNING_RATE),
              loss=dice_loss,
              metrics=['accuracy',
                       iou_score])

# Create custom callbacks for saving the model and the cyclic learning rate
callbacks = get_callbacks(MODEL_PATH)

# Fit the neural network
history = model.fit_generator(
            # Training generator (with shuffling and augmentation)
            generator=train_gen,
            # Validation generator (no shuffling or augmentation)
            validation_data=val_gen,
            # Number of training epochs
            epochs=EPOCHS,
            # Disable Keras' built-in output (the TQDM callback shows progress)
            verbose=0,
            # Use multiprocessing for data pre-processing
            use_multiprocessing=True,
            # Number of worker processes; adjust to your CPU core count
            workers=10,
            # Maximum number of pre-processed batches queued in memory
            max_queue_size=32,
            # Custom callbacks
            callbacks=callbacks)
```

Progress: 25 epochs, 3310 batches per epoch.

Epoch 00001: val_iou_score improved from -inf to 0.67093, saving model to Unet.h5

Epoch 00002: val_iou_score improved from 0.67093 to 0.78291, saving model to Unet.h5

Epoch 00003: val_iou_score did not improve from 0.78291

Epoch 00004: val_iou_score did not improve from 0.78291

Epoch 00005: val_iou_score did not improve from 0.78291

Epoch 00006: val_iou_score did not improve from 0.78291

Epoch 00007: val_iou_score did not improve from 0.78291

Epoch 00008: val_iou_score did not improve from 0.78291

Epoch 00009: val_iou_score improved from 0.78291 to 0.81260, saving model to Unet.h5

Epoch 00010: val_iou_score did not improve from 0.81260

Epoch 00011: val_iou_score improved from 0.81260 to 0.81363, saving model to Unet.h5

Epoch 00012: val_iou_score did not improve from 0.81363

Epoch 00013: val_iou_score did not improve from 0.81363

Epoch 00014: val_iou_score did not improve from 0.81363

Epoch 00015: val_iou_score did not improve from 0.81363

Epoch 00016: val_iou_score did not improve from 0.81363

Epoch 00017: val_iou_score improved from 0.81363 to 0.81493, saving model to Unet.h5

Epoch 00018: val_iou_score improved from 0.81493 to 0.81793, saving model to Unet.h5

Epoch 00019: val_iou_score did not improve from 0.81793

Epoch 00020: val_iou_score improved from 0.81793 to 0.83068, saving model to Unet.h5

Epoch 00021: val_iou_score did not improve from 0.83068

Epoch 00022: val_iou_score did not improve from 0.83068

Epoch 00023: val_iou_score did not improve from 0.83068

Epoch 00024: val_iou_score did not improve from 0.83068

Epoch 00025: val_iou_score did not improve from 0.83068
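The 3310 batches per epoch shown by the progress bar can be reproduced from the split. For a float `test_size`, scikit-learn's `train_test_split` sizes the held-out set as ceil(test_size × n_samples) and assigns the remaining samples to training. Assuming that `DataGenerator.__len__` floors the division by the batch size (as in the linked tutorial), the numbers work out as follows:

```python
import math

N = 55750          # number of training images
TEST_SIZE = 0.05   # fraction held out for validation
BATCH_SIZE = 16

n_val = math.ceil(TEST_SIZE * N)     # validation images
n_train = N - n_val                  # training images
train_steps = n_train // BATCH_SIZE  # complete training batches per epoch
val_steps = n_val // BATCH_SIZE      # complete validation batches

print(n_train, n_val, train_steps, val_steps)  # 52962 2788 3310 174
```

With flooring, the incomplete last batch is skipped: 2 training images per epoch (likely different ones each epoch, since the training data are reshuffled) and the same 4 validation images at every evaluation.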

