### Import Libraries

In [None]:
import importlib
import os
import random

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Conv2D, Dropout, MaxPooling2D, Conv2DTranspose, concatenate
from tensorflow.keras import Sequential

from skimage.io import imread, imshow
from skimage.transform import resize

# Local module. The package directory is named "U-Net"; the hyphen makes
# `import U-Net.UNET as unet` a SyntaxError, so import it by string instead.
# NOTE(review): `unet` is not referenced anywhere in this notebook.
unet = importlib.import_module('U-Net.UNET')

### Define Globals

In [None]:
# Model input size; every training/test image and mask is resized to this on load.
IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3

# Seed every RNG the notebook uses: Python `random` (sample picks), NumPy,
# and TensorFlow (weight initialisation, dropout, batch shuffling).
# Restart & Run All is then repeatable, up to any nondeterminism in GPU kernels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

### Create Model

In [None]:
# A 4-level U-Net. Each encoder level applies two 3x3 convolutions and then
# halves the spatial size. Each decoder level upsamples 2x, concatenates the
# matching encoder output (skip connection) and applies two more convolutions.


def _conv_block(x, filters, dropout):
    """Apply Conv(3x3, ReLU) -> Dropout -> Conv(3x3, ReLU).

    Args:
        x: input tensor.
        filters: number of filters for both convolutions.
        dropout: dropout rate applied between the two convolutions.

    Returns:
        Output tensor of the second convolution (same spatial size as `x`).
    """
    x = Conv2D(filters, 3, activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = Dropout(dropout)(x)
    return Conv2D(filters, 3, activation='relu', kernel_initializer='he_normal', padding='same')(x)


def _up_block(x, skip, filters, dropout):
    """Upsample `x` 2x with a transposed conv, concatenate `skip`, then `_conv_block`."""
    x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    x = concatenate([x, skip])
    return _conv_block(x, filters, dropout)


# (filters, dropout rate) for each encoder/decoder level, shallowest first.
UNET_LEVELS = [(16, 0.1), (32, 0.1), (64, 0.2), (128, 0.2)]

# Input: raw 0-255 pixel values, scaled to [0, 1] inside the model.
# NOTE: Lambda layers are saved as Python bytecode, which is not portable;
# tf.keras.layers.Rescaling(1 / 255) is the portable alternative on newer TF.
inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
x = Lambda(lambda t: t / 255)(inputs)

# Encoder: remember each level's pre-pooling output for the skip connections.
skips = []
for filters, rate in UNET_LEVELS:
    x = _conv_block(x, filters, rate)
    skips.append(x)
    x = MaxPooling2D((2, 2))(x)

# Middle (bottleneck)
x = _conv_block(x, 256, 0.3)

# Decoder: climb back up, deepest level first, pairing each with its skip.
for (filters, rate), skip in zip(reversed(UNET_LEVELS), reversed(skips)):
    x = _up_block(x, skip, filters, rate)

# Output: one-channel per-pixel foreground probability.
outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

# Create Model
model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

### Begin Training

Create the training callbacks (model checkpointing, early stopping, TensorBoard logging)

In [None]:
# Write U-Net.h5 whenever validation loss improves. ModelCheckpoint monitors
# 'val_loss' by default, so only the best epoch so far is kept on disk.
checkPointer = tf.keras.callbacks.ModelCheckpoint('U-Net.h5',
                                                  verbose=1,
                                                  save_best_only=True)

callbacks = [
    # The checkpointer must be in this list, otherwise it never runs during fit.
    checkPointer,

    # Stop after 2 epochs without val_loss improvement. Roll the weights back
    # to the best epoch when stopping, so the `model` used later is the best one.
    tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                     patience=2,
                                     restore_best_weights=True,
                                     verbose=1),

    # Training curves for TensorBoard, written under ./logs.
    tf.keras.callbacks.TensorBoard(log_dir='logs'),
]

Get Data

In [None]:
TRAIN_PATH = 'data/stage1_train/'
TEST_PATH = 'data/stage1_test/'

# Each sample is a sub-directory named after its ID.
trainIDs = next(os.walk(TRAIN_PATH))[1]
testIDs = next(os.walk(TEST_PATH))[1]


def _load_image(sample_dir, sample_id):
    """Read <sample_dir>/images/<sample_id>.png and resize it to the model input size.

    Returns:
        A tuple (resized image with its original value range preserved,
        [height, width] of the image before resizing).
    """
    img = imread(os.path.join(sample_dir, 'images', sample_id + '.png'))
    if img.ndim == 2:
        # Grayscale PNG: replicate the single channel so slicing below works.
        img = np.stack([img] * IMG_CHANNELS, axis=-1)
    img = img[:, :, :IMG_CHANNELS]  # drop any extra (e.g. alpha) channel
    original_size = [img.shape[0], img.shape[1]]
    img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
    return img, original_size


def _load_mask(sample_dir):
    """Union every mask file in <sample_dir>/masks/ into one (H, W, 1) boolean mask.

    Masks are resized with nearest-neighbour interpolation (order=0) and no
    anti-aliasing. The defaults (bilinear plus Gaussian smoothing when
    downsampling) produce small non-zero values around every object edge.
    The `> 0` test would then turn all of those into foreground and dilate
    the masks.
    """
    mask = np.zeros((IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.bool_)
    mask_dir = os.path.join(sample_dir, 'masks')
    for mask_file in next(os.walk(mask_dir))[2]:
        mask_ = imread(os.path.join(mask_dir, mask_file))
        mask_ = resize(mask_, (IMG_HEIGHT, IMG_WIDTH), order=0, mode='constant',
                       preserve_range=True, anti_aliasing=False)
        mask |= np.expand_dims(mask_ > 0, axis=-1)
    return mask


# Get training images
XTrain = np.zeros((len(trainIDs), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
YTrain = np.zeros((len(trainIDs), IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.bool_)

print("Resizing training images and masks")
for n, id_ in tqdm(enumerate(trainIDs), total=len(trainIDs)):
    sample_dir = os.path.join(TRAIN_PATH, id_)
    img, _ = _load_image(sample_dir, id_)
    XTrain[n] = img
    YTrain[n] = _load_mask(sample_dir)

# Get testing images; keep original sizes so predictions can be mapped back.
XTest = np.zeros((len(testIDs), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
sizes_test = []

print("Resizing test images")
for n, id_ in tqdm(enumerate(testIDs), total=len(testIDs)):
    img, original_size = _load_image(os.path.join(TEST_PATH, id_), id_)
    sizes_test.append(original_size)
    XTest[n] = img

print('Done!')

Show one random training sample (image and ground-truth mask)

In [None]:
# Sanity-check loading: display a random training image and its mask.
# randrange excludes the upper bound; randint(0, len(...)) could index past the end.
i = random.randrange(len(XTrain))
imshow(XTrain[i])
plt.show()
imshow(np.squeeze(YTrain[i]))
plt.show()

Train

In [None]:
# Train with Keras holding out the last 10% of XTrain/YTrain (taken before any
# shuffling) as validation data. The Test cell slices the same 90% boundary.
results = model.fit(
    XTrain,
    YTrain,
    validation_split=0.1,
    batch_size=16,
    epochs=25,
    callbacks=callbacks,
)

### Test

In [None]:
# validation_split=0.1 held out the last 10% of the training arrays, so split
# at the same index to separate training-set from validation-set predictions.
split = int(XTrain.shape[0] * 0.9)

predsTrain = model.predict(XTrain[:split], verbose=1)
predsVal = model.predict(XTrain[split:], verbose=1)
predsTest = model.predict(XTest, verbose=1)

# Binarize the sigmoid probabilities at 0.5.
predsTrain_t = (predsTrain > 0.5).astype(np.uint8)
predsVal_t = (predsVal > 0.5).astype(np.uint8)
predsTest_t = (predsTest > 0.5).astype(np.uint8)


def _show_sample(image, truth, pred):
    """Display an input image, its ground-truth mask and the predicted mask, in order."""
    for picture in (image, np.squeeze(truth), np.squeeze(pred)):
        imshow(picture)
        plt.show()


# Sanity Check (Training Sample). randrange excludes the upper bound, so the
# index is always valid (randint(0, len(...)) could raise IndexError).
ix = random.randrange(len(predsTrain_t))
_show_sample(XTrain[ix], YTrain[ix], predsTrain_t[ix])

# Sanity Check (Validation Sample)
ix = random.randrange(len(predsVal_t))
_show_sample(XTrain[split:][ix], YTrain[split:][ix], predsVal_t[ix])