# Train a Multisegmentation Network
by DevNesh

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import keras.backend as K
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import *
from keras.losses import *
import skimage.io as io
import skimage.transform as tr
import skimage.color
import dask.array as da
from glob import glob
from dask.array.image import imread
from skimage import img_as_ubyte

from helper import * 

## Load Data

In [None]:
# Reads more than one mask as ground truth
def _load_binary_mask(path, size, threshold=0.001):
    """Read one mask via ``read_img`` and set its foreground pixels to 1.

    ``read_img`` comes from the project's ``helper`` module; it is assumed to
    return something convertible to a 3-D (H, W, C) array -- TODO confirm.
    """
    mask = np.array(read_img(path, size))
    # Foreground is decided on channel 0 only; the boolean (H, W) index then
    # writes 1 into every channel of the selected pixels. Pixels at or below
    # the threshold keep whatever value they had.
    mask[mask[:, :, 0] > threshold] = 1
    return mask


def read_masks(path, size):
    """Load the three ground-truth masks for every file matching ``path``.

    For each file found by ``glob(path)`` the sibling masks are located by
    replacing ``"masks_01"`` in the file path with ``"masks_02"`` and
    ``"masks_03"``. The three masks are thresholded and concatenated along
    the channel axis (axis 2).

    Args:
        path: Glob pattern selecting the ``masks_01`` files.
        size: Target shape forwarded to ``read_img`` for every mask. The
            previous implementation ignored this argument and always used
            ``(224, 224, 1)``.

    Returns:
        ``np.ndarray`` holding one stacked mask per matched file, in the
        order returned by ``glob``. An empty array when nothing matches.
    """
    # NOTE(review): glob() does not guarantee any ordering. Whether the masks
    # line up with the inputs depends on how helper.read_imgs orders its
    # files -- confirm both use the same ordering.
    mask_dirs = ("masks_01", "masks_02", "masks_03")
    stacked = []
    for count, p in enumerate(glob(path), start=1):
        # Replacing "masks_01" with itself leaves the first path unchanged.
        channels = [
            _load_binary_mask(p.replace("masks_01", mask_dir), size)
            for mask_dir in mask_dirs
        ]
        stacked.append(np.concatenate(channels, axis=2))

        # Progress output every 200 files.
        if count % 200 == 0:
            print(count)
    return np.array(stacked)

# Load the dataset into memory: x = network inputs, y = stacked ground-truth masks.
# Each name is rebound to None before loading; presumably so a previously
# loaded array can be released when the cell is re-run -- confirm.
x = None
x = read_imgs(
    '/home/dan/Desktop/combined_masks/images/data/*.png',
    (224, 224, 1),
)
y = None
y = read_masks(
    '/home/dan/Desktop/combined_masks/masks_01/data/*.png',
    (224, 224, 1),
)

In [None]:
# Plots one input image and its ground-truth masks in the notebook for comparison
i = 4

# Input image (channel 0).
# Fixed: the original indexed the list literal `[i]` instead of `x[i]`,
# which raises a TypeError before anything is drawn.
plt.imshow(x[i][:,:,0], cmap = 'gray')
plt.show()

# Ground-truth masks 1-3 are stored as channels 0-2 of y[i]
for channel in range(3):
    plt.imshow(y[i][:,:,channel], cmap = 'gray')
    plt.show()

## Load Model

In [None]:
# Load a previously trained model from disk and print its layer overview
from keras.models import load_model
model = load_model(
    'modelsave2.h5',
    # NOTE(review): the saved name 'iou_loss' is resolved to f1_loss here --
    # confirm this mapping is intentional.
    custom_objects={'iou_loss': f1_loss},
)

model.summary()

In [None]:
from unet import UNet

# Build a fresh network. This rebinds `model`, so a model loaded by the
# previous cell is replaced when this cell is run afterwards.
model = None
# The positional arguments are defined by the project's unet.UNet; presumably
# the first is the input shape and the second the number of output
# channels -- TODO confirm against unet.py.
model = UNet(
    (224, 224, 1),
    3,
    16,
    4,
    2.0,
)
model.summary()

In [None]:
# Training callbacks
# Stop training once 'val_loss' has not improved for 5 consecutive epochs
earlyStop = EarlyStopping(
    monitor='val_loss',
    patience=5,
)
# Write the model to disk during training, keeping only the best one seen so far
checkpoint = ModelCheckpoint(
    'training_multi_best5.h5',
    save_best_only=True,
)

## Train Model

In [None]:
# Configure training: Adam optimizer, f1_loss as the objective, and the
# remaining helper functions reported as additional metrics.
model.compile(
    optimizer=Adam(lr=0.0001),
    loss=f1_loss,
    metrics=[iou_loss, precision, error, recall],
)

In [None]:
# Number of leading samples used for training (= 80% of all given data);
# everything from this index onwards is used for validation.
train = 2450
result = model.fit(
    x[:train],
    y[:train],
    batch_size=32,
    epochs=20,
    validation_data=(x[train:], y[train:]),
    shuffle=True,
    callbacks=[earlyStop, checkpoint],
)

## Show Results

In [None]:
# Show which quantities were recorded during training
history = result.history
print(history.keys())

# Loss curves: training loss first, then validation loss
# (the validation curve is the one labelled 'test' in the legend)
for curve in ('loss', 'val_loss'):
    plt.plot(history[curve])
plt.title('model loss')   # title of the graph
plt.ylabel('loss')        # y-axis label
plt.xlabel('epoch')       # x-axis label
plt.legend(['train', 'test'], loc='upper left')
plt.show()

# Dump every recorded value
print(history)

## Save Model

In [None]:
# Write the model in its current state to 'training_multi5.h5'.
# This is a different file from 'training_multi_best5.h5', which the
# ModelCheckpoint callback writes during training.
model.save('training_multi5.h5')

## Visualize Predictions

In [None]:
# Load the images to predict on, using the same (224,224,1) size argument as
# the training inputs. read_imgs comes from the star-imported helper module;
# presumably it returns one array stacking all images -- TODO confirm.
testdata = read_imgs('/home/dan/Desktop/multipredict/test/images/data/*.png', (224,224,1))

In [None]:
# Runs the model on the whole test set in one call; verbose=1 shows progress.
# The result is indexed below as pred[i][:,:,c] for c in 0..2.
pred = model.predict(testdata, verbose=1)

In [None]:
# Plots one test image and its prediction in the notebook for comparison
i = 13

# Input picture (channel 0)
plt.imshow(testdata[i, ..., 0], cmap='gray')
plt.axis('off')
plt.show()

# All predicted channels of sample i shown together in one image
plt.imshow(pred[i])
plt.axis('off')
plt.show()

# Each predicted channel on its own: Dim 0, Dim 1, Dim 2
for dim in range(3):
    plt.imshow(pred[i][:,:,dim], cmap = 'gray')
    plt.axis('off')
    plt.show()