# Training U-Net for BAGLS dataset

Load the required components.

In [None]:
import numpy as np
import keras
from keras.optimizers import Adam
import tensorflow as tf
import random

# Splitting train and validation data
from sklearn.model_selection import train_test_split

# Loss and evaluation metric
from segmentation_models.losses import dice_loss
from segmentation_models.metrics import iou_score

In [None]:
# File the trained network is saved to (handed to the callbacks further below)
MODEL_PATH = "Unet.h5"
# Folder containing the training images and their "_seg" masks.
# Keep the trailing slash: file names are appended to this string directly.
TRAINING_PATH = "C:/BAGLS/training/"

## Data pre-processing and dynamic feeding

We use a [data generator](https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly) to load images on the fly during training and to parallelize the data pre-processing (e.g. image augmentation).

In [None]:
from Utils.DataGenerator import DataGenerator

## Custom callbacks

We use the following custom callbacks:

- Cyclic learning rate ($10^{-3}$ to $10^{-6}$)
- Saving model if validation IoU is greater than the previous one
- TQDM-based progress bar

In [None]:
from Utils.Callbacks import get_callbacks

## Load neural network components and build U-Net

In [None]:
from Utils.Unet import Unet

In [None]:
# Build the U-Net using the default arguments of Utils.Unet.Unet.
# The resulting `model` is compiled and trained in the training cell below.
model = Unet()
# Show an overview of the network (presumably the standard Keras layer /
# parameter table - assumes Unet() returns a Keras model).
model.summary()

## Load images

In [None]:
# Total number of training samples; files are indexed 0 ... N-1
N = 55750

# Image paths follow "<TRAINING_PATH><index>.png" and the matching
# segmentation masks follow "<TRAINING_PATH><index>_seg.png"
train_imgs = ["{}{}.png".format(TRAINING_PATH, idx) for idx in range(N)]
train_segs = ["{}{}_seg.png".format(TRAINING_PATH, idx) for idx in range(N)]

## Training

In [None]:
# Set random seeds (NumPy, TensorFlow, Python) for reproducible training
# NOTE(review): `model` was already built in an earlier cell, so its weight
# initialization may not be covered by these seeds - confirm.
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)
random.seed(SEED)

# Define training parameters
BATCH_SIZE = 16  # adjust for your graphics card
# Initial learning rate 10^-3. Note that the literal `10e-3` would be
# 10 * 10^-3 = 0.01, i.e. ten times larger than intended.
LEARNING_RATE = 1e-3
EPOCHS = 25

# Hard split of training and validation data (5 % held out for validation).
# Image and mask paths are passed together so that they stay paired.
X, X_val, y, y_val = train_test_split(train_imgs,
                                      train_segs,
                                      test_size=0.05,
                                      random_state=SEED)

# Augment and shuffle training data
train_gen = DataGenerator(X,
                          y,
                          BATCH_SIZE,
                          augment=True,
                          shuffle=True)

# Do not augment and shuffle validation data
val_gen = DataGenerator(X_val,
                        y_val,
                        BATCH_SIZE,
                        augment=False,
                        shuffle=False)

# Compile model with optimizer (Adam with Cyclic Learning Rate)
#  and DICE loss.
# The learning rate is passed positionally because the keyword name differs
# between Keras versions (`lr` vs. `learning_rate`). 1e-3 equals the Keras
# default, so this only makes the starting value explicit.
model.compile(optimizer=Adam(LEARNING_RATE),
              loss=dice_loss,
              metrics=['accuracy',
                       iou_score])

# Create custom callbacks for saving model and cyclic learning rate
callbacks = get_callbacks(MODEL_PATH)

# Fit the neural network
history = model.fit_generator(
            # Training generator (with shuffling and augmentation)
            generator=train_gen,
            # Validation generator (no shuffling and augmentation)
            validation_data=val_gen,
            # Train for EPOCHS
            epochs=EPOCHS,
            # No built-in Keras output
            verbose=0,
            # Multiprocessing for data pre-processing
            use_multiprocessing=True,
            # How many cores are utilized in multiprocessing, adjust for your CPU cores
            workers=10,
            # Batches in memory
            max_queue_size=32,
            # Custom Callbacks
            callbacks=callbacks)