## UNet for image segmentation

## step 1: loading required packages

In [None]:
# set seeds to ensure repeatability of results
from numpy.random import seed
# Seed NumPy's global RNG (also drives the np.random sampling of images to display later).
seed(101)
# NOTE(review): Python's built-in `random` module is not seeded here.

import pandas as pd
import numpy as np
import os
# import cv2
import tensorflow as tf
# Seed TensorFlow's global RNG as well.
# NOTE(review): GPU kernels may still be nondeterministic even with a fixed seed — confirm if exact repeatability matters.
tf.random.set_seed(101)

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Conv2D, MaxPool2D, Activation
from tensorflow.keras.layers import concatenate, BatchNormalization, Conv2DTranspose

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.regularizers import l2

import matplotlib
import matplotlib.pyplot as plt
# %matplotlib inline
import skimage
from skimage.io import imread, imshow
from skimage.transform import resize
from skimage.measure import label, regionprops

# Don't Show Warning Messages
# NOTE: 'ignore' suppresses every warning category, including deprecation notices.
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Report the library versions in use so results can be tied to an environment.
for _version_label, _module in (
    ("numpy versions is: ", np),
    ("pandas versions is: ", pd),
    ("tensorflow versions is: ", tf),
    ("skimage versions is: ", skimage),
    ("matplotlib versions is: ", matplotlib),
):
    print(_version_label, _module.__version__)

## step 2: check how many (CUDA) GPUs are available

In [None]:
# Count the GPU devices TensorFlow reports as physically available.
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

## step 3: list all image files, load them, and resize them

In [None]:
import glob as gb

# glob returns files in arbitrary order. The cell images and their dot
# annotations are paired by list index below, so sort both listings to
# line up files that share a name prefix.
cell_img_list = sorted(gb.glob("./dataset02/*cells.png"))
dots_img_list = sorted(gb.glob("./dataset02/*dots.png"))
if not dots_img_list:
    raise FileNotFoundError("no '*dots.png' images found in ./dataset02/")

# Native image size, taken from the first (2-D, single-channel) dot image.
(IMG_HEIGHT, IMG_WIDTH) = imread(dots_img_list[0]).shape

PADDING = 40
# NUM_TEST_IMAGES = 10

In [None]:
# =====================
# Create X_test
# =====================

# Gray-scale input: one channel per image.
IMG_CHANNELS = 1

# Cells and dot annotations are paired by list index, so the counts must agree.
if len(cell_img_list) != len(dots_img_list):
    raise ValueError(
        f"found {len(cell_img_list)} '*cells.png' images but "
        f"{len(dots_img_list)} '*dots.png' images; they must pair one-to-one"
    )

X_test = np.zeros((len(cell_img_list), IMG_HEIGHT, IMG_WIDTH))
Y_test = np.zeros((len(dots_img_list), IMG_HEIGHT, IMG_WIDTH))

for i, (cell_path, dots_path) in enumerate(zip(cell_img_list, dots_img_list)):
    X_test[i, :, :] = imread(cell_path, as_gray=True)
    # Dot images are stored without gray conversion (already single-channel).
    # NOTE(review): if the *dots.png files are 8-bit, their values land here as
    # 0..255, while the network's sigmoid output lies in (0, 1). Confirm the
    # annotation value range and rescale if needed.
    Y_test[i, :, :] = imread(dots_path)
print(X_test.shape)
# (number of images, native height, native width); b1/c1 are reused later to
# resize predictions back to the native resolution.
(a1, b1, c1) = X_test.shape

In [None]:
# =====================
# Resizing X_test
# =====================

# Network input/output resolution (square).
IMG_HEIGHT_input = IMG_WIDTH_input = 512
IMG_HEIGHT_output = IMG_WIDTH_output = 512

# Channels-last 4-D stacks (N, H, W, C), matching the shape given to the Input layer.
X_test_2 = np.zeros((len(cell_img_list), IMG_HEIGHT_input, IMG_WIDTH_input, IMG_CHANNELS))
Y_test_2 = np.zeros((len(dots_img_list), IMG_HEIGHT_output, IMG_WIDTH_output, IMG_CHANNELS))

# Interpolate each image / annotation pair to the network size; channel 0 holds the values.
for idx, cell_img in enumerate(X_test):
    X_test_2[idx, :, :, 0] = resize(cell_img, (IMG_HEIGHT_input, IMG_WIDTH_input))
    Y_test_2[idx, :, :, 0] = resize(Y_test[idx, :, :], (IMG_HEIGHT_output, IMG_WIDTH_output))


## skip to step 6 if no training is needed (steps 4–5 build, train, and save the model).

In [None]:
# Visual sanity check: the first cell image beside its dot annotation, both shown at 512x512.
fig0, ax0 = plt.subplots(1, 2, figsize=(10, 5))
for _panel, _stack in zip(ax0, (X_test, Y_test)):
    _panel.imshow(resize(_stack[0, :, :], (512, 512)))

## step 4: set up the model (based on Wendi Xie's work)

In [None]:
# L2 weight-decay factor; not referenced by the model-building functions in this cell.
weight_decay = 1e-5

def get_crop_shape(target, refer):
    """Return the crop that trims ``target`` down to the spatial size of ``refer``.

    Both arguments must provide ``get_shape()`` indexed as
    (batch, height, width, ...). The result is
    ``((top, bottom), (left, right))``. When a size difference is odd, the
    extra row/column goes to the second side (bottom/right).

    Raises AssertionError if ``refer`` is larger than ``target`` along the
    width (checked first) or the height.
    """
    def _halves(diff):
        # Split a non-negative difference into two near-equal parts; the odd
        # leftover unit is added to the second part.
        first = int(diff / 2)
        return first, first + diff % 2

    target_shape = target.get_shape()
    refer_shape = refer.get_shape()

    width_diff = target_shape[2] - refer_shape[2]
    assert width_diff >= 0

    height_diff = target_shape[1] - refer_shape[1]
    assert height_diff >= 0

    return _halves(height_diff), _halves(width_diff)

def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True):
    """Apply two stages of Conv2D -> (optional) BatchNormalization -> ReLU.

    Each stage uses ``n_filters`` square kernels of side ``kernel_size``,
    'same' padding and He-normal initialisation. The output tensor of the
    second stage is returned.
    """
    out = input_tensor
    for _stage in range(2):
        out = Conv2D(filters=n_filters,
                     kernel_size=(kernel_size, kernel_size),
                     kernel_initializer='he_normal',
                     padding='same')(out)
        if batchnorm:
            out = BatchNormalization()(out)
        out = Activation('relu')(out)
    return out

def get_unet(input_img, n_filters = 16, dropout = 0.1, batchnorm = True):
    """Build a 4-level U-Net mapping an image to a one-channel sigmoid map.

    Parameters
    ----------
    input_img : Keras input tensor (e.g. from ``tf.keras.layers.Input``).
        Its known height/width must be divisible by 16. The contracting path
        halves them four times, and the expansive path doubles them back
        before concatenating with the matching skip connection.
    n_filters : int
        Filters at the first level; level k (k = 0..4) uses ``n_filters * 2**k``.
    dropout : float
        Dropout rate applied after every pooling step and every skip concatenation.
    batchnorm : bool
        Whether each ``conv2d_block`` uses BatchNormalization.

    Returns
    -------
    tf.keras.Model
        Model from ``input_img`` to a map with the input's spatial size and a
        single sigmoid-activated channel.

    Raises
    ------
    ValueError
        If a known spatial dimension of ``input_img`` is not divisible by 16.
    """
    # Fail early with a clear message instead of a shape error in concatenate().
    for dim in input_img.shape[1:3]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                f"U-Net input height/width must be divisible by 16, got {tuple(input_img.shape[1:3])}"
            )

    # Contracting Path
    c1 = conv2d_block(input_img, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)
    p1 = MaxPool2D((2, 2))(c1)
    p1 = Dropout(dropout)(p1)

    c2 = conv2d_block(p1, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)
    p2 = MaxPool2D((2, 2))(c2)
    p2 = Dropout(dropout)(p2)

    c3 = conv2d_block(p2, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)
    p3 = MaxPool2D((2, 2))(c3)
    p3 = Dropout(dropout)(p3)

    c4 = conv2d_block(p3, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)
    p4 = MaxPool2D((2, 2))(c4)
    p4 = Dropout(dropout)(p4)

    # Bottleneck
    c5 = conv2d_block(p4, n_filters = n_filters * 16, kernel_size = 3, batchnorm = batchnorm)

    # Expansive Path: upsample x2, then concatenate with the same-level skip tensor
    u6 = Conv2DTranspose(n_filters * 8, (3, 3), strides = (2, 2), padding = 'same')(c5)
    u6 = concatenate([u6, c4])
    u6 = Dropout(dropout)(u6)
    c6 = conv2d_block(u6, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)

    u7 = Conv2DTranspose(n_filters * 4, (3, 3), strides = (2, 2), padding = 'same')(c6)
    u7 = concatenate([u7, c3])
    u7 = Dropout(dropout)(u7)
    c7 = conv2d_block(u7, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)

    u8 = Conv2DTranspose(n_filters * 2, (3, 3), strides = (2, 2), padding = 'same')(c7)
    u8 = concatenate([u8, c2])
    u8 = Dropout(dropout)(u8)
    c8 = conv2d_block(u8, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)

    u9 = Conv2DTranspose(n_filters * 1, (3, 3), strides = (2, 2), padding = 'same')(c8)
    u9 = concatenate([u9, c1])
    u9 = Dropout(dropout)(u9)
    c9 = conv2d_block(u9, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)

    # 1x1 convolution to a single channel squashed into (0, 1).
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(c9)
    model = Model(inputs=[input_img],
                  outputs=[outputs])
    return model

In [None]:
# Symbolic input of shape (IMG_HEIGHT_input, IMG_WIDTH_input, IMG_CHANNELS); batch size left open.
Input_img = tf.keras.layers.Input(
    shape=(IMG_HEIGHT_input, IMG_WIDTH_input, IMG_CHANNELS)
)
model_unet = get_unet(Input_img, n_filters=16, dropout=0.1, batchnorm=True)

model_unet.summary()

In [None]:
# Where the checkpoint callback writes the best model seen so far.
filepath = "model.h5"

# NOTE(review): patience=15 exceeds epochs=10 below, so early stopping cannot
# trigger in this run; lower the patience or raise the epochs if it is wanted.
earlystopper = EarlyStopping(patience=15, verbose=1)

# Only 'accuracy' is compiled as a metric, so no 'val_mean_squared_error' key
# is ever logged. Because loss='mse', 'val_loss' is the validation MSE, and it
# is the value to monitor when keeping the best model.
checkpoint = ModelCheckpoint(filepath, monitor='val_loss',
                             verbose=1, mode='min', save_best_only=True)

callbacks_list = [earlystopper, checkpoint]
model_unet.compile(optimizer='adam', loss='mse', metrics=['accuracy'])

# The last 15% of the image stack is held out for validation.
history = model_unet.fit(X_test_2, Y_test_2,
                         validation_split=0.15, batch_size=1,
                         epochs=10, callbacks=callbacks_list
                         )

## step 5: saving the model

In [None]:
# Persist the trained U-Net to an HDF5 (.h5) file so step 6 can reload it.
model_unet.save("U-net.model.v3.h5")

## step 6: loading the model

In [None]:
# Reload the model written in step 5 (lets you skip training entirely).
model_unet = tf.keras.models.load_model(filepath="U-net.model.v3.h5")

In [None]:
# Pick 8 images to visualise. Indices are drawn from the number of images
# actually loaded, not a hard-coded count, and are distinct whenever at
# least 8 images are available.
n_images = len(X_test_2)
num_list = np.random.choice(n_images, size=8, replace=n_images < 8)

results = model_unet.predict(X_test_2[num_list,:,:,:])
# One 2-D map per selected image: channel 0 of the network's sigmoid output.
img_1, img_2, img_3, img_4, img_5, img_6, img_7, img_8 = (
    results[k, :, :, 0] for k in range(8)
)

In [None]:
# fig2, ax2 = plt.subplots(2,1,figsize=(8,14))
# ax2[0].imshow(X_test[-8,:,:])
# ax2[1].imshow(resize(img_1,(b1,c1)))

In [None]:
# Bring each predicted map back to the native image size (b1 x c1), threshold
# it at 0.5 and label the connected foreground blobs.
_pred_maps = (img_1, img_2, img_3, img_4, img_5, img_6, img_7, img_8)
(labeled_img_1, labeled_img_2, labeled_img_3, labeled_img_4,
 labeled_img_5, labeled_img_6, labeled_img_7, labeled_img_8) = [
    label(resize(pred, (b1, c1)) > 0.5) for pred in _pred_maps
]

# Region properties (area, centroid, ...) of every labelled blob, per image.
regions_total = [
    regionprops(lab)
    for lab in (labeled_img_1, labeled_img_2, labeled_img_3, labeled_img_4,
                labeled_img_5, labeled_img_6, labeled_img_7, labeled_img_8)
]
(regions_1, regions_2, regions_3, regions_4,
 regions_5, regions_6, regions_7, regions_8) = regions_total

# 4x2 grid with one panel per selected image, shown at native resolution.
fig3, ax3 = plt.subplots(4, 2, figsize=(15, 30))
ax3 = np.ravel(ax3)
for panel, img_idx in enumerate(num_list):
    ax3[panel].imshow(X_test[img_idx, :, :])

# Mark and count blobs larger than 30 pixels (presumably to drop small
# spurious detections); centroids are (row, col), so plot col as x.
i0 = np.array([0] * 8)
for panel, regions in enumerate(regions_total):
    kept = [blob for blob in regions if blob.area > 30]
    for blob in kept:
        ax3[panel].plot(blob.centroid[1], blob.centroid[0], 'r+')
    i0[panel] += len(kept)

for axis, n_cells in zip(ax3, i0):
    axis.set_title("number of cells = " + str(n_cells))

fig3.tight_layout()