In [None]:
import numpy as np
import matplotlib.pyplot as plt
import keras
import os
import time
import pickle

from sl_model import SL512, DSODSL512
from ssd_data import InputGenerator
from sl_utils import PriorUtil
from sl_training import SegLinkLoss, SegLinkFocalLoss

from utils.training import Logger, LearningRateDecay
from utils.model import load_weights, calc_memory_usage

### Data

In [None]:
# GTUtility lives in data_synthtext; the pickle below presumably stores a GTUtility
# instance. pickle re-imports the defining module on load by itself, so this import
# mainly makes the dependency explicit.
from data_synthtext import GTUtility

# Cached ground-truth utility for SynthText (SegLink variant, per the file name).
# NOTE: pickle.load can execute arbitrary code - only load pickles you created or trust.
file_name = 'gt_util_synthtext_seglink.pkl'
with open(file_name, 'rb') as f:
    gt_util = pickle.load(f)
# Split into training / validation parts; 0.9 is presumably the training fraction
# (TODO confirm against GTUtility.split, including whether the split is random).
gt_util_train, gt_util_val = gt_util.split(0.9)

print(gt_util_train)

### Model

In [None]:
# SegLink (SL512), initialised from an SSD512 VOC checkpoint.
# Alternative model cell: run only ONE of the model cells (SL512 / DSODSL512 / ResNet).
# Each one rebinds `model`, `weights_path`, `batch_size` and `experiment`, so under
# "Run All" the last model cell executed wins.
experiment = 'sl512_synthtext'
#experiment = 'sl512fl_synthtext'  # presumably for the focal-loss variant
batch_size = 24
weights_path = './models/ssd512_voc_weights_fixed.hdf5'
#weights_path = '~/.keras/models/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
model = SL512()

In [None]:
# SegLink with a DenseNet (DSOD-style) backbone; no pretrained weights are loaded
# (weights_path is None), so training starts from the constructor's initialisation.
# Alternative model cell: run only ONE of the model cells - the last one executed wins.
experiment = 'dsodsl512_synthtext'
batch_size = 6
weights_path = None
model = DSODSL512()
#model = DSODSL512(activation='leaky_relu')

In [None]:
# SegLink with a ResNet backbone; no pretrained weights are loaded (weights_path is None).
# Alternative model cell: run only ONE of the model cells - the last one executed wins.
# The import stays local to this cell so the other configurations do not depend on it.
experiment = 'sl512_resnet_synthtext'
batch_size = 10
weights_path = None
from ssd_model_resnet import SL512_resnet
model = SL512_resnet()

In [None]:
# Optionally initialise the model from pretrained weights, then build the prior utility.
# - SSD512 checkpoints: copy only the layers in `layer_list` and freeze the lowest
#   conv blocks listed in `freeze`.
# - any other weights file: pass the whole file to load_weights.
if weights_path is not None:
    # Expand '~' so home-relative paths (e.g. the ~/.keras/... alternative) resolve;
    # the file loader does not do this by itself.
    weights_file = os.path.expanduser(weights_path)
    if 'ssd512' in weights_path:
        layer_list = [
            'conv1_1', 'conv1_2',
            'conv2_1', 'conv2_2',
            'conv3_1', 'conv3_2', 'conv3_3',
            'conv4_1', 'conv4_2', 'conv4_3',
            'conv5_1', 'conv5_2', 'conv5_3',
            'fc6', 'fc7',
            'conv6_1', 'conv6_2',
            'conv7_1', 'conv7_2',
            'conv8_1', 'conv8_2',
            'conv9_1', 'conv9_2',
        ]
        freeze = [
            'conv1_1', 'conv1_2',
            'conv2_1', 'conv2_2',
            'conv3_1', 'conv3_2', 'conv3_3',
            #'conv4_1', 'conv4_2', 'conv4_3',
            #'conv5_1', 'conv5_2', 'conv5_3',
        ]

        load_weights(model, weights_file, layer_list)
        # Trainable flags take effect when the model is compiled (training cell below).
        # NOTE(review): every layer NOT in `freeze` is forced to trainable, overriding
        # any non-trainable flags set by the model constructor - confirm intended.
        for layer in model.layers:
            layer.trainable = layer.name not in freeze
    else:
        load_weights(model, weights_file)

prior_util = PriorUtil(model)

### Training

In [None]:
epochs = 100
initial_epoch = 0  # epoch index to start from (passed to fit_generator)

# NOTE(review): augmentation is disabled for the training generator as well as the
# validation one - confirm this is intended.
gen_train = InputGenerator(gt_util_train, prior_util, batch_size, model.image_size, augmentation=False)
gen_val = InputGenerator(gt_util_val, prior_util, batch_size, model.image_size, augmentation=False)

# One checkpoint directory per run: ./checkpoints/<YYYYmmddHHMM>_<experiment>
# (two runs started in the same minute share a directory).
checkdir = './checkpoints/' + time.strftime('%Y%m%d%H%M') + '_' + experiment
# exist_ok replaces the racy os.path.exists() check-then-create pattern.
os.makedirs(checkdir, exist_ok=True)

# Snapshot the input of every cell executed so far into the checkpoint directory,
# so a checkpoint can be traced back to the notebook code that produced it.
# `In` is IPython's input history; it does not exist when this code is run as a
# plain Python script, in which case the snapshot is skipped instead of raising NameError.
cell_sources = globals().get('In')
if cell_sources is not None:
    source = ''.join('# In[%i]\n%s\n\n' % (i, cell_source)
                     for i, cell_source in enumerate(cell_sources))
    # Explicit UTF-8 with newline='\n' writes the same bytes as the previous
    # binary-mode write of source.encode(), without platform newline translation.
    with open(checkdir + '/source.py', 'w', encoding='utf-8', newline='\n') as f:
        f.write(source)
else:
    print('IPython input history (In) not available; skipping source snapshot')

# Optimizer. Alternative that was tried:
#   keras.optimizers.SGD(lr=1e-3, momentum=0.9, decay=0, nesterov=True)
# Note the comparatively large Adam epsilon (1e-3).
optim = keras.optimizers.Adam(
    lr=1e-3,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=0.001,
    decay=0.0,
)

# Weight decay: L2 penalty on the kernels of all convolutional layers.
regularizer = keras.regularizers.l2(5e-4) # None if disabled
#regularizer = None
# The model is already built at this point, and Keras only turns `kernel_regularizer`
# into a loss inside `add_weight()` while a layer is being built. Assigning the
# attribute alone therefore left training unregularized. The penalty is now attached
# explicitly with `add_loss()`; this must run before `model.compile()` so that the
# compiled total loss includes it.
# NOTE: re-running this cell without rebuilding `model` adds the penalty a second time.
# NOTE(review): if the model constructors already set kernel regularizers, those layers
# would be penalised twice - TODO confirm they don't.
for layer in model.layers:
    if layer.__class__.__name__.startswith('Conv'):
        # keep the layer config consistent (relevant if the model is serialized/rebuilt)
        layer.kernel_regularizer = regularizer
        if regularizer is not None and getattr(layer, 'kernel', None) is not None:
            layer.add_loss(regularizer(layer.kernel))

# SegLink loss (hard-negative ratio 3:1). Focal-loss alternatives that were tried:
#   SegLinkFocalLoss()
#   SegLinkFocalLoss(lambda_segments=1.0, lambda_offsets=1.0, lambda_links=1.0)
#   SegLinkFocalLoss(gamma_segments=3, gamma_links=3)
loss = SegLinkLoss(
    lambda_offsets=1.0,
    lambda_links=1.0,
    neg_pos_ratio=3.0,
)

model.compile(
    optimizer=optim,
    loss=loss.compute,
    metrics=loss.metrics,
)

# Train from the (endless) batch generators; weights are checkpointed after every epoch.
# use_multiprocessing is left at the Keras default (the commented-out `pickle_safe`
# was its Keras 1 name); a single worker with a one-batch queue keeps memory use low.
history = model.fit_generator(
    gen_train.generate(),
    steps_per_epoch=gen_train.num_batches,
    epochs=epochs,
    verbose=1,
    callbacks=[
        keras.callbacks.ModelCheckpoint(
            checkdir+'/weights.{epoch:03d}.h5',
            verbose=1,
            save_weights_only=True,
        ),
        Logger(checkdir),
        # LearningRateDecay() can be appended here to enable LR decay
    ],
    validation_data=gen_val.generate(),
    validation_steps=gen_val.num_batches,
    class_weight=None,
    max_queue_size=1,
    workers=1,
    initial_epoch=initial_epoch,
)