# Train a model on a big dataset
by DevNesh

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import keras.backend as K
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import *
from keras.losses import *
from keras.preprocessing.image import ImageDataGenerator
import skimage.io as io
import skimage.transform as tr
import skimage.color
from sklearn.metrics import classification_report
from glob import glob

## Own Scripts
from helper import * 
from loss_metrics import *
from unet import *

## Settings For Training

In [None]:
# Training configuration shared by the cells of this notebook.

# Root directory of the dataset on the local machine.
dsPath = '/home/dan/Desktop/Datenset_Block3'

# Sample counts per split. train_size and val_size are divided by batch_size
# to derive the step counts for training; test_size is not used in the cells
# shown in this notebook.
train_size, val_size, test_size = 18300, 3922, 3922

# Batch size handed to loadData and used when deriving the step counts.
batch_size = 32

# Number of epochs requested from the training call.
epochs = 40

## Load Data

In [None]:
# Build the data generators for the training and validation splits.
# loadData is not defined in this notebook (presumably provided by helper);
# it is called with the image directory, the mask directory and the batch size.

# Training split
trainInputPath = dsPath + '/train/images'
trainOutputPath = dsPath + '/train/masks'
trainGen = loadData(trainInputPath, trainOutputPath, batch_size)

# Validation split
valInputPath = dsPath + '/validate/images'
valOutputPath = dsPath + '/validate/masks'
valGen = loadData(valInputPath, valOutputPath, batch_size)

## Load Model

In [None]:
# Build a fresh, untrained model.
# Rebinding `model` to None first drops the reference to any model left over
# from an earlier run of this cell before the new one is constructed.
model = None

# UNet is not defined in this notebook (presumably provided by unet).
# The first argument looks like an input shape of (height, width, channels);
# the meaning of the remaining positional arguments is defined by UNet -- confirm there.
input_shape = (224, 224, 1)
model = UNet(input_shape, 1, 16, 5, 2.0, batchnorm=True)

# Print the layer overview of the new model.
model.summary()

In [None]:
# Alternative to building a new model: continue from a previously saved one.
# Running this cell rebinds `model`, replacing whatever it referred to before.
from keras.models import load_model

# iou_loss is a custom function, so it is supplied via custom_objects for
# deserialisation.
# NOTE(review): if the file was saved with custom metrics compiled in, those
# may need to be listed here as well -- confirm against how ds_step08.h5 was saved.
model = load_model(
    'ds_step08.h5',
    custom_objects={'iou_loss': iou_loss},
)

# Print the layer overview of the loaded model.
model.summary()

In [None]:
# Callbacks handed to the training call.

# Write the model to the results folder during training; with
# save_best_only=True the file is only overwritten when the monitored
# quantity improves (monitor is left at the Keras default).
checkpoint = ModelCheckpoint(
    dsPath + '/results/training_01_best.h5',
    save_best_only=True,
)

# Stop training once val_loss has not improved for 5 consecutive epochs.
earlyStop = EarlyStopping(monitor='val_loss', patience=5)

## Train Model

In [None]:
# Configure the model for training: Adam optimiser with a learning rate of
# 1e-4, iou_loss as the objective, and a set of additional metrics that are
# reported alongside the loss. The loss and metric functions are not defined
# in this notebook (presumably provided by loss_metrics).
model.compile(
    optimizer=Adam(lr=1e-4),
    loss=iou_loss,
    metrics=[f1, iou, precision, recall, error],
)

In [None]:
# Train the model on trainGen and validate on valGen after every epoch.
#
# fit_generator expects integer step counts, but sample_count / batch_size is
# generally a float (e.g. 18300 / 32 = 571.875). Round up so the step count
# covers all samples of the split.
steps_per_epoch = int(np.ceil(train_size / batch_size))
validation_steps = int(np.ceil(val_size / batch_size))

# `result` is the History object returned by Keras; its .history dict holds
# the per-epoch loss and metric values.
# NOTE(review): per the Keras docs, `shuffle` in fit_generator only has an
# effect for keras.utils.Sequence inputs -- confirm what loadData returns.
result = model.fit_generator(
    trainGen,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=valGen,
    validation_steps=validation_steps,
    verbose=1,
    shuffle=True,
    callbacks=[earlyStop, checkpoint],
)

## Show Validation Graph

In [None]:
# Show which quantities were recorded during training.
print(result.history.keys())

# Plot training loss against validation loss per epoch.
plt.plot(result.history['loss'])
plt.plot(result.history['val_loss'])
plt.title('model loss')  # title of the graph
plt.ylabel('loss')       # name of the y-axis
plt.xlabel('epoch')      # name of the x-axis
# val_loss is computed on the validation generator, not on a test split,
# so the second curve is labelled accordingly.
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

# Dump the full per-epoch history (loss and all metrics).
print(result.history)

## Save Model

In [None]:
# Persist the model in its current state to the results folder of the dataset.
final_model_path = dsPath + '/results/training_01.h5'
model.save(final_model_path)