In [None]:
import os, math, fnmatch
from contextlib import redirect_stdout

import numpy as np
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, TensorBoard, LearningRateScheduler
from keras.losses import categorical_crossentropy

from utils.data import trainGenerator, testGenerator, saveResult
from utils.helpers import get_label_info
from utils.customloss import weighted_categorical_crossentropy
from model.refinenet import build_refinenet

## Definitions 

In [None]:
# Define custom learning rate schedule
def step_decay(epoch):
    """Step-wise learning-rate schedule for ``LearningRateScheduler``.

    Starts at 1e-5 and multiplies the rate by 0.5 each time ``epoch + 1``
    crosses another multiple of 5.

    Args:
        epoch: zero-based epoch index supplied by the scheduler callback.

    Returns:
        The learning rate (float) to use for that epoch.
    """
    base_rate = 1e-5
    decay_factor = 0.5
    decay_period = 5.0
    # Number of completed decay periods, counting epochs from 1.
    completed_periods = math.floor((epoch + 1) / decay_period)
    return base_rate * math.pow(decay_factor, completed_periods)

In [None]:
# Define parameters
# Dataset root. Set the CITYSCAPES_PATH environment variable to point elsewhere
# instead of editing the notebook; falls back to the original location.
dataset_basepath = os.environ.get('CITYSCAPES_PATH', '/data/Cityscapes')
class_dict = 'class_dict.csv'

resnet_weights = 'resnet101_weights_tf.h5'

input_shape = (768,768,3)
batch_size = 1

# Random augmentation settings, handed to the training data generator below.
# NOTE(review): if these are forwarded to keras ImageDataGenerator, rotation_range
# is measured in degrees, so 0.1 is a near-zero rotation -- confirm it is intended.
data_gen_args = dict(rotation_range=0.1,
                    width_shift_range=0.05,
                    height_shift_range=0.05,
                    shear_range=0.05,
                    zoom_range=0.05,
                    horizontal_flip=True,
                    fill_mode='nearest')

save_summary = True  # when True, model.summary() is written to RefineNet_summary.txt after building

# Import classes from csv file
class_names_list, mask_colors, num_class, class_names_string = get_label_info(os.path.join(dataset_basepath,class_dict))

# Define custom loss function to ignore void class (https://github.com/keras-team/keras/issues/6261)
# Assuming last class is void
def ignore_unknown_xentropy(ytrue, ypred):
    """Categorical cross-entropy with the void class masked out.

    Args:
        ytrue: ground-truth tensor whose last axis indexes classes
            (presumably one-hot); the last class is treated as void.
        ypred: predicted class probabilities, same shape as ``ytrue``.

    Returns:
        The per-element cross-entropy (class axis reduced away), weighted by
        ``1 - ytrue[..., -1]``. Pixels labelled void therefore contribute
        nothing to the loss.
    """
    # Take the void channel from the end of the class axis. This matches the
    # "last class is void" convention for any input rank and does not depend
    # on the notebook-global num_class.
    void_mask = ytrue[..., -1]
    return (1 - void_mask) * categorical_crossentropy(ytrue, ypred)

# Data generators for training
# Random augmentation is applied to the training split only. Validation images
# are fed unaugmented so that val_loss, which ModelCheckpoint monitors, is not
# perturbed by random transforms from epoch to epoch.
# NOTE(review): assumes trainGenerator treats an empty augmentation dict as
# "no augmentation" -- confirm in utils/data.py.
myTrainGenerator = trainGenerator(batch_size,
                                  os.path.join(dataset_basepath,'training'),
                                  'images',
                                  'labels',
                                  num_class,
                                  input_shape,
                                  data_gen_args,
                                  mask_colors = mask_colors)
myValGenerator = trainGenerator(batch_size,
                                os.path.join(dataset_basepath,'validation'),
                                'images',
                                'labels',
                                num_class,
                                input_shape,
                                dict(),
                                mask_colors = mask_colors)

# Define callbacks
# Save checkpoints into the same 'checkpoints' directory that the
# "Load previous weights" cell reads from, creating it if needed.
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
model_checkpoint = ModelCheckpoint(os.path.join(checkpoint_dir, 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'),
                                   monitor = 'val_loss',
                                   verbose = 1,
                                   save_best_only = True)

# write_grads only affects gradient histograms, which histogram_freq=0 disables,
# so it is omitted here.
tbCallBack = TensorBoard(log_dir='./log', histogram_freq=0,
                         write_graph=True,
                         batch_size=batch_size,
                         write_images=True)

# Per-epoch learning rate comes from step_decay.
lrate = LearningRateScheduler(step_decay)

## Build model

In [None]:
# Build the RefineNet graph and compile it with the void-masked loss
freeze_frontend = True  # set to False to unfreeze frontend layers
model = build_refinenet(input_shape, num_class, resnet_weights, freeze_frontend)
model.compile(optimizer=Adam(lr=1e-5),
              loss=ignore_unknown_xentropy,
              metrics=['accuracy'])

# Optionally capture the printed layer summary in a text file
if save_summary:
    with open('RefineNet_summary.txt', 'w') as summary_file, redirect_stdout(summary_file):
        model.summary()

## Training

In [None]:
# Load previous weights if available
# Guarded so a fresh Restart & Run All (no checkpoint on disk yet) does not
# crash; training then starts from the weights build_refinenet produced.
initial_weights = 'checkpoints/weights.01-0.30.hdf5'
if os.path.exists(initial_weights):
    model.load_weights(initial_weights)
else:
    print('No checkpoint found at {}, starting from the freshly built model'.format(initial_weights))

In [None]:
# Start training
# NOTE(review): 2975 is presumably the number of training images (the Cityscapes
# train split size) -- confirm. steps_per_epoch is derived from it so one epoch
# stays one pass over the data if batch_size is changed.
num_train_images = 2975
# TODO: when resuming from a loaded checkpoint, pass initial_epoch so step_decay
# and the checkpoint epoch numbers continue instead of restarting from 0.
history = model.fit_generator(myTrainGenerator,
                              steps_per_epoch = math.ceil(num_train_images / batch_size),
                              validation_data = myValGenerator,
                              validation_steps = 30,
                              epochs = 50,
                              callbacks = [model_checkpoint, tbCallBack, lrate])

## Inference

In [None]:
out_dir = 'img'
out_size = (1024,2048,3)  # size predictions are saved at; presumably the full Cityscapes frame size -- confirm
num_test_images = 30  # predict_generator steps; presumably one test image per step -- confirm against testGenerator

# Make sure the output directory exists before anything is written into it.
os.makedirs(out_dir, exist_ok=True)
myTestGenerator = testGenerator(os.path.join(dataset_basepath,'testing/images'), input_shape, out_dir)
results = model.predict_generator(myTestGenerator, steps = num_test_images, verbose=1)
saveResult(results, out_dir, out_size, mask_colors)