# Run network model

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("/home/ubuntu/roaddetection/")
sys.path.append("/mnt/hd_internal/hh/projects_DS/road_detection/roaddetection")
from src.models.data import trainGenerator
from src.models.network_models import unet, unet_var
from src.data.utils import get_tile_prefix
from src.models.metrics_img import auc_pr
import matplotlib.pyplot as plt
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, LambdaCallback
from keras.optimizers import Adam
#from pathlib import Path
import os, shutil, platform
%matplotlib inline

### Define directories

In [None]:
# Root locations, relative to the notebook's working directory.
data_dir = "../../data"
model_dir = "../../models/UNet"
report_dir = "../../reports"

# Per-split data directories; `dirs` collects them, in this order, for the tile count below.
train_dir = os.path.join(data_dir, "train")
train_partial_dir = os.path.join(data_dir, "train_partial")
validation_dir = os.path.join(data_dir, "validate")
test_dir = os.path.join(data_dir, "test")
dirs = [train_dir, train_partial_dir, validation_dir, test_dir]

## User settings

In [None]:
# ------------- image characteristics and augmentation -----------------------------
# size of tiles passed to the generators
target_size = (512, 512)
# input arguments to Keras' ImageDataGenerator (random horizontal and vertical flips)
data_gen_args = dict(
    data_format="channels_last",
    horizontal_flip=True,
    vertical_flip=True,
)
# directory into which to place *training* images from ImageDataGenerator for inspection;
# default should be None because writing the images to disk slows things down
imgdatagen_dir = None
# imgdatagen_dir = data_dir + '/imgdatagenerator'

#--------------- network weights ----------------------------------------------------
# path to & filename of pre-trained model to load - set to None to start from scratch.
# Previously used checkpoints (uncomment one to warm-start from it):
# pretrained_model_fn = model_dir + '/models_unet_borneo_and_harz_05_09_16_22.hdf5'
# pretrained_model_fn = model_dir + '/unet_var_test.hdf5'
pretrained_model_fn = None

# path to & filename under which the trained model is saved
trained_model_fn = model_dir + '/unet_test.hdf5'

#--------------- training details / hyperparameters -----------------------------------
# batch size
batch_size = 1
# steps per epoch, should correspond to [number of training images] / batch size;
# NOTE(review): 600 is presumably the tile count of the training split - update it
# whenever that split changes.
steps_per_epoch = 600 // batch_size
# number of epochs
epochs = 10
# number of batches drawn from the validation generator per epoch
validation_steps = 60
# optimizer, loss and metrics handed to model.compile
# NOTE(review): newer Keras releases rename Adam's `lr` argument to `learning_rate`.
optimizer = Adam(lr=2e-4)
loss = 'binary_crossentropy'
loss_weights = None
metrics = ['accuracy', auc_pr]

### Count image tiles in train/validation/test directories

In [None]:
# Report how many files each split holds in its 'sat', 'map' and 'sat_rgb' subfolders.
for split_dir in dirs:
    for subfolder in ("sat", "map", "sat_rgb"):
        tile_dir = os.path.join(split_dir, subfolder)
        n_tiles = len(os.listdir(tile_dir))
        print(tile_dir, ":", n_tiles)

print("Done.")

### Set up ImageDataGenerators for training and validation sets

In [None]:
# Training generator: reads the train_partial split and optionally dumps the
# generated training images to `imgdatagen_dir` for inspection.
# 'sat' / 'map' are the subfolder names handed to trainGenerator (presumably
# image and mask folders respectively - TODO confirm against trainGenerator).
train_gen = trainGenerator(batch_size, train_partial_dir, 'sat', 'map',
                           data_gen_args, save_to_dir=imgdatagen_dir,
                           image_color_mode="rgba", target_size=target_size)

# Validation generator: same arguments, but never saves images to disk.
# NOTE(review): it also receives `data_gen_args`, so validation tiles get the same
# random flips as training tiles - confirm this is intended.
validation_gen = trainGenerator(batch_size, validation_dir, 'sat', 'map',
                                data_gen_args, save_to_dir=None,
                                image_color_mode="rgba", target_size=target_size)

### Define model, compile, show summary, possibly load weights, define callbacks (including checkpoints)

In [None]:
# Build the network (switch to `unet()` for the other architecture) and compile it
# with the settings chosen above.
model = unet_var()
#model = unet()
model.compile(optimizer=optimizer,
              loss=loss,
              loss_weights=loss_weights,
              metrics=metrics)
model.summary()

# Optionally warm-start from previously saved weights.
if pretrained_model_fn:
    model.load_weights(pretrained_model_fn)

# Write the model to `trained_model_fn` each time the monitored *training* loss improves.
# NOTE(review): this monitors 'loss', not 'val_loss'; switch if checkpoints should
# track validation performance instead.
model_checkpoint = ModelCheckpoint(trained_model_fn, monitor='loss', verbose=1, save_best_only=True)

# Stop training if the training loss doesn't improve for 5 consecutive epochs (patience=5).
early_stop = EarlyStopping(monitor='loss', min_delta=0, patience=5, verbose=1, mode='auto', baseline=None)

### Run training

In [None]:
import logging

def get_logger(log_path='../../logs/unet-2.log'):
    """Return this module's logger, writing INFO-level records to ``log_path``.

    Safe to call repeatedly (e.g. when a notebook cell is re-run): the file
    handler for ``log_path`` is attached only once, so lines are not duplicated.
    The parent directory of ``log_path`` is created if it does not exist.

    Parameters
    ----------
    log_path : str
        File to append log records to. Defaults to the path used previously.

    Returns
    -------
    logging.Logger
        The logger named after this module.
    """
    log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_fmt)

    logger = logging.getLogger(__name__)
    # basicConfig does nothing if the root logger already has handlers (as in some
    # Jupyter kernels); without an explicit level, INFO records would be dropped.
    logger.setLevel(logging.INFO)

    # FileHandler stores the absolute path in `baseFilename`; compare against that
    # so re-running the function does not attach a second handler for the same file.
    abs_path = os.path.abspath(log_path)
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == abs_path
        for handler in logger.handlers
    )
    if not already_attached:
        log_dir = os.path.dirname(abs_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        fh = logging.FileHandler(abs_path)
        fh.setFormatter(logging.Formatter(log_fmt))
        logger.addHandler(fh)
    return logger

In [None]:
logger = get_logger()


def _log_epoch_end(epoch, logs):
    """Log the epoch index together with the metrics dict Keras passes at epoch end."""
    logger.info({'epoch': epoch, 'logs': logs})


# Keras calls `_log_epoch_end(epoch, logs)` after every training epoch.
logging_callback = LambdaCallback(on_epoch_end=_log_epoch_end)

In [None]:
# Callbacks: best-loss checkpointing, early stopping, and per-epoch file logging.
training_callbacks = [model_checkpoint, early_stop, logging_callback]

# NOTE(review): `fit_generator` is deprecated in newer Keras in favour of `fit`.
history = model.fit_generator(
    train_gen,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=validation_gen,
    validation_steps=validation_steps,
    callbacks=training_callbacks,
)

In [None]:
def _plot_metric_pair(history, train_key, val_key, save_path=None):
    """Plot one training metric curve against its validation counterpart.

    If ``save_path`` is given, the figure is saved *before* it is shown: showing
    a figure (e.g. with the inline backend) closes it, so a later ``savefig``
    would write an empty canvas.
    """
    plt.plot(history[train_key], label=train_key)
    plt.plot(history[val_key], label=val_key)
    plt.legend()
    if save_path:
        plt.savefig(save_path)
    plt.show()
    plt.close()


def plot_history(history, save_path="../../logs/unet_borneo_and_harz_05_09_11_11.jpg"):
    """Plot accuracy, loss and AUC-PR curves (training vs. validation).

    Parameters
    ----------
    history : dict
        Keras ``History.history``: metric name -> list of per-epoch values.
    save_path : str or None
        File the AUC-PR figure is saved to; ``None`` skips saving.
        Defaults to the previously hard-coded location.
    """
    # Older Keras logs the 'accuracy' metric as 'acc', newer versions as 'accuracy'.
    acc_key = "acc" if "acc" in history else "accuracy"
    _plot_metric_pair(history, acc_key, "val_" + acc_key)
    _plot_metric_pair(history, "loss", "val_loss")
    _plot_metric_pair(history, "auc_pr", "val_auc_pr", save_path=save_path)


plot_history(history.history)