<a href="https://colab.research.google.com/github/slowvak/AI-Deep-Learning-Lab/blob/master/Image_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Magician's Corner: Image Segmentation with tf.keras and U-Nets
Full Article can be found here:

<table class="tfo-notebook-buttons" align="left"><td>
<a target="_blank"  href="http://colab.research.google.com/github/tensorflow/models/blob/master/samples/outreach/blogs/segmentation_blogpost/image_segmentation.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>  
</td><td>
<a target="_blank"  href="https://github.com/tensorflow/models/blob/master/samples/outreach/blogs/segmentation_blogpost/image_segmentation.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a></td></table>

In [None]:
#Cell #1 --  Set up environment by loading libraries.
# Be sure to check that you have GPU runtime
import os
import glob
import zipfile
import functools

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
mpl.rcParams['figure.figsize'] = (12,12)

from sklearn.model_selection import train_test_split
import matplotlib.image as mpimg
import pandas as pd

from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization
from keras.layers.core import Dropout, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K

import tensorflow as tf


# Load the {{PROCESSED}} images and the masks for pancreas from RSNA repository
# To make load times reasonable, only slices with pancreas are included
#  and these are 8 bit JPEG images (compresssed)

!rm -rf ./MC4-TensorflowUNet
!git clone https://github.com/slowvak/MC4-TensorflowUNet.git

!rm -rf images
!mkdir images
!rm -rf masks
!mkdir masks


limit = 400  # limit number of subjects due to GPU memory limits


for f in os.listdir('./MC4-TensorflowUNet'):
    cmd = 'unzip ./MC4-TensorflowUNet/{}'.format(f)
    os.system(cmd)
    limit = limit - 1
    if limit < 0:
        break
        
!mv *-Mask.jpg ./masks
!mv *.jpg ./images


In [None]:
#Cell #2 -- We will now split the data into train and test. The syntax is fairly different from FastAI
import sys
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label

IMG_HEIGHT = IMG_WIDTH = 256
IMG_CHANNELS = 1
img_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

# Get train images and masks.
# sorted(): os.listdir returns entries in arbitrary order, which would make the
# train/test split differ from run to run.
image_filenames = sorted(os.listdir('./images'))
mask_filenames = sorted(os.listdir('./masks'))
assert mask_filenames, "no mask files found in ./masks -- run Cell 1 first"

# The loading loop below iterates over the masks, so the arrays are sized from
# the mask list; sizing them from the image list could overflow the test arrays
# if the two directories ever disagree.
num_examples = len(mask_filenames)
num_train_examples = int(num_examples * 0.8) - 1
num_val_examples = num_examples - num_train_examples

# Plain `bool` replaces the `np.bool` alias, which was removed in NumPy 1.24.
X_train = np.zeros((num_train_examples, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
Y_train = np.zeros((num_train_examples, IMG_HEIGHT, IMG_WIDTH, 1), dtype=bool)
X_test = np.zeros((num_val_examples, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
Y_test = np.zeros((num_val_examples, IMG_HEIGHT, IMG_WIDTH, 1), dtype=bool)
print('Getting and resizing train images and masks ... ')

sys.stdout.flush()


# No resize is applied: each file must already be IMG_HEIGHT x IMG_WIDTH with a
# single channel, otherwise the assignments below raise a shape error.
# NOTE(review): storing an 8-bit JPEG mask in a bool array marks every non-zero
# pixel as foreground, including faint compression artifacts -- consider an
# explicit threshold (e.g. mask > 127) after checking the data.
for n, f in enumerate(mask_filenames):
    mask = np.expand_dims(imread('./masks/' + f), axis=2)
    # The image shares the mask's file name minus the "-Mask" suffix.
    img = np.expand_dims(imread('./images/' + f.replace("-Mask.", ".")), axis=2)
    if n < num_train_examples:
        X_train[n] = img
        Y_train[n] = mask
    else:
        # Everything after the first num_train_examples files goes to the test set.
        i = n - num_train_examples
        X_test[i] = img
        Y_test[i] = mask

print("Number of training examples: {}".format(num_train_examples))
print("Number of validation examples: {}".format(num_val_examples))

Display some of the images in our dataset. 

In [None]:
#Cell #3
# Show a few randomly chosen training images next to their masks.
display_num = 5

# Sampling is with replacement, so the same example may appear more than once.
r_choices = np.random.choice(num_train_examples, display_num)

plt.figure(figsize=(10, 15))
for row, img_num in enumerate(r_choices):
  # Left column: the image. Drop the trailing channel axis for imshow, using the
  # configured size rather than a hard-coded 256 so the cell follows IMG_HEIGHT/IMG_WIDTH.
  plt.subplot(display_num, 2, 2 * row + 1)
  plt.imshow(np.reshape(X_train[img_num], (IMG_HEIGHT, IMG_WIDTH)), cmap='gray')
  plt.title("Original Image")

  # Right column: the corresponding mask.
  plt.subplot(display_num, 2, 2 * row + 2)
  plt.imshow(np.reshape(Y_train[img_num], (IMG_HEIGHT, IMG_WIDTH)), cmap='gray')
  plt.title("Masked Image")

plt.suptitle("Examples of Images and their Masks")
plt.show()

In [None]:
# Archive the contents of ./masks into masks.zip.
# The equivalent command for ./images is left disabled below.
#!zip -r images.zip ./images/*
!zip -r masks.zip ./masks/*


In [None]:
# Cell 4
# Build U-Net model
# Activation passed to the 3x3 Conv2D layers of the network built in this cell.
act_fn = 'relu'
# Kernel initializer passed to those same 3x3 Conv2D layers.
init_fn = 'he_normal'

def dice_coeff(y_true, y_pred):
    """Smoothed Dice coefficient computed over all elements of both tensors.

    Both inputs are flattened, so the score is one scalar for the whole batch
    rather than one value per example.

    Args:
        y_true: ground-truth mask; may be a bool/integer array or tensor.
        y_pred: predicted mask (e.g. the sigmoid output of the network).

    Returns:
        Scalar tensor (2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1).
    """
    # The +1 in numerator and denominator avoids 0/0 when both masks are empty.
    smooth = 1.

    # Bring both operands to a common floating dtype. Without this, calling the
    # function directly with a bool or uint8 mask fails on the multiplication,
    # because TensorFlow does not mix dtypes implicitly. When the dtypes already
    # match (as when Keras calls the metric) the casts are no-ops.
    y_pred = tf.convert_to_tensor(y_pred)
    if not y_pred.dtype.is_floating:
        y_pred = tf.cast(y_pred, tf.float32)
    y_true = tf.cast(y_true, y_pred.dtype)

    # Flatten
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return score

def dice_loss(y_true, y_pred):
    """Return one minus the Dice coefficient of the two masks.

    A perfect overlap gives a loss of 0; the value grows as the overlap shrinks.
    """
    return 1 - dice_coeff(y_true, y_pred)

#binary cross entropy is another function that can perform well
def bce_dice_loss(y_true, y_pred):
    """Sum of binary cross-entropy and dice loss for the given masks.

    Args:
        y_true: ground-truth mask.
        y_pred: predicted mask probabilities.

    Returns:
        Tensor holding binary cross-entropy plus the (scalar) dice loss.
    """
    # `losses` is not imported anywhere at file level, so the original body raised
    # NameError on first use. Importing it here keeps the function self-contained.
    from keras import losses

    loss = losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss



def _conv_block(x, filters, dropout_rate, batch_norm=False):
    """Two 3x3 'same' convolutions with dropout (and optional batch norm) in between.

    Layer order is Conv2D -> [BatchNormalization] -> Dropout -> Conv2D.
    """
    x = Conv2D(filters, (3, 3), activation=act_fn, kernel_initializer=init_fn, padding='same') (x)
    if batch_norm:
        x = BatchNormalization() (x)
    x = Dropout(dropout_rate) (x)
    return Conv2D(filters, (3, 3), activation=act_fn, kernel_initializer=init_fn, padding='same') (x)


def _up_block(x, skip, filters, dropout_rate):
    """Upsample `x` 2x, concatenate the skip connection, then apply a conv block.

    Returns a (merged, output) pair: the concatenated tensor and the conv block result.
    """
    merged = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same') (x)
    # Default concatenation axis is the last (channel) axis.
    merged = concatenate([merged, skip])
    return merged, _conv_block(merged, filters, dropout_rate)


inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
# Scale 8-bit pixel values into [0, 1] inside the model itself.
s = Lambda(lambda x: x / 255) (inputs)

# Contracting path: filters double at each level while max-pooling halves the resolution.
c1 = _conv_block(s, 16, 0.1)
p1 = MaxPooling2D((2, 2)) (c1)

# Only this level uses batch normalization.
c2 = _conv_block(p1, 32, 0.1, batch_norm=True)
p2 = MaxPooling2D((2, 2)) (c2)

c3 = _conv_block(p2, 64, 0.2)
p3 = MaxPooling2D((2, 2)) (c3)

c4 = _conv_block(p3, 128, 0.2)
p4 = MaxPooling2D(pool_size=(2, 2)) (c4)

# Bottleneck.
c5 = _conv_block(p4, 256, 0.3)

# Expanding path: each step upsamples and merges the same-resolution encoder output.
u6, c6 = _up_block(c5, c4, 128, 0.2)
u7, c7 = _up_block(c6, c3, 64, 0.2)
u8, c8 = _up_block(c7, c2, 32, 0.1)
u9, c9 = _up_block(c8, c1, 16, 0.1)

# 1x1 convolution with sigmoid gives a per-pixel foreground probability.
outputs = Conv2D(1, (1, 1), activation='sigmoid') (c9)

model = Model(inputs=[inputs], outputs=[outputs])
# Trained with binary cross-entropy; dice_loss is only tracked as a metric.
# Cell 6 reloads the model with custom_objects={'dice_loss': dice_loss}, so the
# metric function must keep this name.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[dice_loss])
model.summary()

# Defining custom metrics and loss functions
Defining loss and metric functions is simple with Keras: define a function that takes the true labels for a given example and the predicted labels for that same example.
Dice loss is derived from the Dice coefficient, a metric that measures the overlap between the predicted and true masks (dice loss = 1 − Dice coefficient). More information on optimizing for the Dice coefficient can be found in the V-Net paper (Milletari et al., 2016), where it was introduced.

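Dice loss is attractive here because it performs better on class-imbalanced problems by design. In addition, maximizing the Dice coefficient and IoU metrics is the actual objective of our segmentation task; cross entropy is more of a proxy that is easier to optimize, whereas minimizing dice loss optimizes the objective directly. Note that the model in Cell 4 is compiled with `loss='binary_crossentropy'` and tracks `dice_loss` only as a metric; pass `loss=dice_loss` to `model.compile` to train on dice loss directly.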

In [None]:
#Cell 5
# Training hyperparameters.
epochs = 50
batch_size = 32

# Fit model
# Neither callback names a quantity to monitor, so both use the Keras default.
# earlystopper stops training after 5 epochs without improvement;
# checkpointer writes model.h5 only when the monitored value improves.
earlystopper = EarlyStopping(patience=5, verbose=1)
checkpointer = ModelCheckpoint('model.h5', verbose=1, save_best_only=True)
training_callbacks = [earlystopper, checkpointer]

# validation_split=0.2 holds part of X_train/Y_train out of training for validation.
results = model.fit(
    X_train,
    Y_train,
    validation_split=0.2,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=training_callbacks,
)

In [None]:
#Cell 6
# Predict on train, val and test
# Reload the checkpoint written during training; dice_loss must be supplied
# because the model was compiled with it as a custom metric.
model = load_model('model.h5', custom_objects={'dice_loss': dice_loss})

# model.fit was called with validation_split=0.2, which holds out the LAST 20% of
# X_train for validation. The boundary therefore sits at 80%; the previous 0.9
# counted half of the validation samples as training predictions.
val_start = int(X_train.shape[0] * 0.8)
preds_train = model.predict(X_train[:val_start], verbose=1)
preds_val = model.predict(X_train[val_start:], verbose=1)
preds_test = model.predict(X_test, verbose=1)

# Threshold predictions
# Probabilities above 0.5 become foreground (1), everything else background (0).
preds_train_t = (preds_train > 0.5).astype(np.uint8)
preds_val_t = (preds_val > 0.5).astype(np.uint8)
preds_test_t = (preds_test > 0.5).astype(np.uint8)

# Create list of upsampled test masks
# Each (un-thresholded) test prediction is resized to IMG_HEIGHT x IMG_WIDTH
# after dropping its singleton axes.
preds_test_upsampled = [
    resize(np.squeeze(pred), (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
    for pred in preds_test
]
    

In [None]:
#Cell 7
# Perform a sanity check on some random training samples
import random

# random.randint includes BOTH endpoints, so the upper bound must be len - 1;
# using len(...) could pick an index one past the end and raise IndexError.
ix = random.randint(0, len(preds_train_t) - 1)

# Show, in order: the input image, its ground-truth mask, and the thresholded prediction.
imshow(np.squeeze(X_train[ix]))
plt.show()
imshow(np.squeeze(Y_train[ix]))
plt.show()
imshow(np.squeeze(preds_train_t[ix]))
plt.show()