# Train a small dataset
by DevNesh

In [None]:
# Import statements
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import keras.backend as K
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import *
from keras.losses import *
import skimage.io as io
import skimage.transform as tr
import skimage.color
import dask.array as da
from glob import glob
from dask.array.image import imread

import helper
import loss_metrics

## Load Data

In [None]:
# Load the dataset: x holds the network inputs, y the masks / ground truth.
# Both calls pass the same second argument (224,224,1); presumably the
# per-image target shape -- confirm against read_imgs.
# `import helper` binds only the module name, so the bare name `read_imgs`
# raised NameError; the function is reached through the module instead.
# NOTE(review): assumes read_imgs is defined in helper.py -- confirm.
x = helper.read_imgs('/home/dan/Desktop/Neural Network/data/dataset1_1/images/data/*.png', (224,224,1))
y = helper.read_imgs('/home/dan/Desktop/Neural Network/data/dataset1_1/masks/data/*.png', (224,224,1))

In [None]:
# Sanity check inside the notebook: show one sample and its mask.
i = 95  # index of the sample to inspect

# Shape of the selected input after taking index 0 along the last axis
print(x[i, ..., 0].shape)

# First the input, then the ground truth -- each drawn as its own
# grayscale figure.
for stack in (x, y):
    plt.imshow(stack[i, ..., 0], cmap='gray')
    plt.show()

## Load Model

In [None]:
# Load a previously saved model from 'modelsave2.h5' and print its layers.
# Note: the U-Net cell that follows rebinds `model`, so run only one of the
# two cells depending on whether training should resume or start fresh.
from keras.models import load_model
# The custom loss is handed to Keras via custom_objects under the key
# 'iou_loss'. `import loss_metrics` binds only the module name, so the bare
# name `iou_loss` raised NameError; it is reached through the module instead.
# NOTE(review): assumes iou_loss is defined in loss_metrics.py -- confirm.
model = load_model('modelsave2.h5', custom_objects={'iou_loss': loss_metrics.iou_loss})
model.summary()

In [None]:
# Build a fresh U-Net and print its layer summary.
# NOTE(review): the first argument (224,224,1) is presumably the input shape;
# the meaning of the remaining positional arguments (1, 16, 4, 2.0) is defined
# in unet.py and not visible here -- confirm against UNet's signature.
from unet import UNet
# Drop the reference to any previously bound model before building the new one.
model = None
model = UNet((224,224,1), 1, 16, 4, 2.0)
model.summary()

In [None]:
# Callbacks handed to model.fit

# Writes the model to 'training_01.h5'; with save_best_only only the best
# model seen so far is kept on disk.
checkpoint = ModelCheckpoint(
    'training_01.h5',
    save_best_only=True,
)

# Ends training once 'val_loss' has gone 5 epochs without improving.
earlyStop = EarlyStopping(
    monitor='val_loss',
    patience=5,
)

## Train Model

In [None]:
# Compile the model; the learning rate and the loss function are set here.
# Optimizer: Adam with learning rate 0.0001; 'accuracy' is reported as an
# additional metric.
# `import loss_metrics` binds only the module name, so the bare name
# `iou_loss` raised NameError; it is reached through the module instead.
# NOTE(review): assumes iou_loss is defined in loss_metrics.py -- confirm.
model.compile(optimizer=Adam(lr=0.0001), loss=loss_metrics.iou_loss, metrics=['accuracy'])

In [None]:
# Fit the model: the first `train` samples are used for training, everything
# after that index for validation.
train = 2043  # hard-coded split index, = 80% of all given data
result = model.fit(
    x[:train],
    y[:train],
    validation_data=(x[train:], y[train:]),
    batch_size=32,
    epochs=10,
    shuffle=True,
    callbacks=[earlyStop, checkpoint],
)

## Show Results

In [None]:
# Which series were recorded during training
print(result.history.keys())

# Loss curves: training loss first, validation loss second
for series in ('loss', 'val_loss'):
    plt.plot(result.history[series])

# Axis names, graph name and legend
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('model loss')
# Legend order follows the plot order; the 'test' entry labels the
# val_loss curve.
plt.legend(['train', 'test'], loc='upper left')
plt.show()

# Dump every recorded value
print(result.history)

## Save Model

In [None]:
# Save the model in its current state to 'ds1_1_small.h5'.
# NOTE(review): this is the model as it stands after fit() returns, which may
# differ from the best checkpoint written to 'training_01.h5' -- confirm which
# file downstream code should load.
model.save('ds1_1_small.h5')