# 1. Dependency libraries

In [None]:
from model.seg_hrnet import seg_hrnet
from utils.loss import *
from utils.metrics import *
from dataloaders.generater import *
import os
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.optimizers import SGD

# 2. Params

In [None]:
# ---- Network configuration ----
# BatchSize is handed to both seg_hrnet() and batch_generator() further down.
BatchSize = 1
NumChannels = 3
ImgHeight = ImgWidth = 512
NumClass = 1

# ---- Training configuration ----
GPUs = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = GPUs  # exported before the model is built
Optimizer = 'Adam'  # alternative: SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
Loss = dice_loss
Metrics = ['accuracy', iou]
NumEpochs, Patience = 100, 10  # epoch cap for fitting / patience for EarlyStopping

# ---- Data locations ----
TrainImageDir = 'D:/DATA/AerialImageDataset/data/train/images/'
# NOTE(review): despite its name, this points at the training `gt/` folder and
# is passed as the second argument of get_data_paths(); presumably it is the
# label/mask directory rather than a validation image directory -- confirm.
ValImageDir = 'D:/DATA/AerialImageDataset/data/train/gt/'

# ---- Visualization ----
# History keys plotted after training; each one is also looked up with a
# 'val_' prefix.
# NOTE(review): 'acc' assumes Keras logs accuracy under its short name -- confirm
# against the installed Keras version.
metric_list = ['acc', 'iou']

In [None]:
# Build the network through seg_hrnet() from the configured batch size,
# input height/width, channel count and class count.
model = seg_hrnet(
    BatchSize,
    ImgHeight,
    ImgWidth,
    NumChannels,
    NumClass,
)
model.summary()  # print the model summary before training starts

# Attach the optimizer, loss and metrics chosen in the params cell.
model.compile(
    loss=Loss,
    optimizer=Optimizer,
    metrics=Metrics,
)

In [None]:
# Checkpoint file name template; its placeholders are filled from the epoch
# number and the logged validation values.
# NOTE(review): the `val_acc` placeholder assumes Keras logs accuracy under the
# short `acc` name; newer releases log `val_accuracy` instead -- confirm.
model_path = "seg_hrnet-{epoch:02d}-{val_loss:.4f}-{val_acc:.4f}-{val_iou:.4f}.hdf5"

# save_best_only=False: checkpoints are not restricted to the best val_iou.
model_checkpoint = ModelCheckpoint(
    model_path,
    monitor='val_iou',
    mode='max',
    save_best_only=False,
    verbose=1,
)

# Stop training once val_iou has not improved for `Patience` epochs.
early_stop = EarlyStopping(
    monitor='val_iou',
    mode='max',
    patience=Patience,
)

# Callbacks handed to the training call.
check_point_list = [model_checkpoint, early_stop]

In [None]:
# Resolve the training and validation sample paths from the two configured
# directories (how get_data_paths() splits them is not visible here).
train_paths, val_paths = get_data_paths(TrainImageDir, ValImageDir)

# Steps per pass count only whole batches (floor division by BatchSize).
train_steps, val_steps = (
    len(paths) // BatchSize for paths in (train_paths, val_paths)
)

In [None]:
# Batch sources for training and validation, created in the same order the
# original call evaluated them.
train_batches = batch_generator(train_paths, BatchSize)
val_batches = batch_generator(val_paths, BatchSize)

# Train for at most NumEpochs epochs; the callbacks in check_point_list run
# during training. The returned object is kept for the plots below.
result = model.fit_generator(
    generator=train_batches,
    steps_per_epoch=train_steps,
    validation_data=val_batches,
    validation_steps=val_steps,
    epochs=NumEpochs,
    callbacks=check_point_list,
    verbose=1,
)

In [None]:
# Figure 1: every tracked metric together with its validation counterpart,
# drawn as a line with '*' markers on each epoch.
# NOTE(review): `plt` is not imported by name in this file; presumably it
# arrives through one of the star imports -- confirm, or import
# matplotlib.pyplot explicitly.
plt.figure()
for metric in metric_list:  # fixed: the original read `ioun` instead of `in`
    val_metric = 'val_' + metric
    for key in (metric, val_metric):
        plt.plot(result.epoch, result.history[key], label=key)
        plt.scatter(result.epoch, result.history[key], marker='*')
# fixed: 'under right' is not a valid legend location; 'lower right' is.
plt.legend(loc='lower right')
plt.show()

# Figure 2: training and validation loss per epoch.
plt.figure()
plt.plot(result.epoch, result.history['loss'], label="loss")
plt.plot(result.epoch, result.history['val_loss'], label="val_loss")
plt.scatter(result.epoch, result.history['loss'], marker='*')
plt.scatter(result.epoch, result.history['val_loss'], marker='*')
plt.legend(loc='upper right')
plt.show()