# Training a latent space-restricted U-Net on the BAGLS dataset

Load the required components

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import tensorflow as tf
import random
import json

# Splitting train and validation data
from sklearn.model_selection import train_test_split

# Loss and evaluation metric
from segmentation_models.losses import dice_loss
from segmentation_models.metrics import iou_score
from tensorflow.keras.optimizers import Adam

## Data pre-processing and dynamic feeding

We use a [data generator](https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly) to provide images on the fly during training and to parallelize the data pre-processing (e.g. image augmentation). Prior to training, we resized the `BAGLS` dataset to 512$\times$256 px.

In [None]:
from DataGenerator import DataGenerator512x256

## Load neural network components and build U-Net

In [None]:
from Unet import Unet

Define the paths to the training images and their segmentation masks

In [None]:
# Directory that holds the resized BAGLS training images and their masks
TRAINING_PATH = "C:/BAGLS/training_512x256/"

# Number of training samples; files are indexed 0 .. N-1
N = 55750

# Sample idx is stored as "<idx>.png", its segmentation mask as "<idx>_seg.png"
train_imgs = [f"{TRAINING_PATH}{idx}.png" for idx in range(N)]
train_segs = [f"{TRAINING_PATH}{idx}_seg.png" for idx in range(N)]

# Train 

In [None]:
# Define model and check its summary.
# Builds the project-local U-Net with its default arguments and prints its
# architecture summary (presumably a tf.keras Model — it is later compiled
# with a tf.keras optimizer; TODO confirm in the Unet module).
# NOTE(review): this runs before the random seeds are set in the training
# cell, so the initial weights are not covered by those seeds.
model = Unet()
model.summary()

In [None]:
# Seed every RNG the training run may draw from.
# - NumPy / Python `random`: presumably used by DataGenerator512x256 for
#   shuffling and augmentation (TODO confirm in DataGenerator).
# - TensorFlow: the global TF seed (previously never set) governs TF-side
#   randomness such as weight initialisation and dropout.
# NOTE(review): `model` was instantiated in the previous cell, i.e. before
# these seeds are set, so its initial weights are not reproducible from
# SEED. Rebuild the model after this point if fully reproducible runs are
# required.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)

BATCH_SIZE = 32  # adjust for your graphics card (GPU memory)

LEARNING_RATE = 1e-3
EPOCHS = 25

# Fixed, seeded split: 95 % of the samples for training, 5 % for validation
X, X_val, y, y_val = train_test_split(
    train_imgs,
    train_segs,
    test_size=0.05,
    random_state=SEED,
)

# Training data: augmented and shuffled
train_gen = DataGenerator512x256(
    X,
    y,
    BATCH_SIZE,
    augment=True,
    shuffle=True,
)

# Validation data: neither augmented nor shuffled, so (presumably) every
# epoch is evaluated on the same batches
val_gen = DataGenerator512x256(
    X_val,
    y_val,
    BATCH_SIZE,
    augment=False,
    shuffle=False,
)

# Compile with the Adam optimizer and Dice loss; track accuracy and IoU
model.compile(
    optimizer=Adam(LEARNING_RATE),
    loss=dice_loss,
    metrics=["accuracy", iou_score],
)

history = model.fit(
    # Training generator (with shuffling and augmentation)
    train_gen,
    # Validation generator (no shuffling, no augmentation)
    validation_data=val_gen,
    # Train for EPOCHS full passes over the training data
    epochs=EPOCHS,
)