# U-Net for Medical Image Segmentation
**Welcome to the U-Net tutorial for Image Segmentation!**

For this tutorial we will be using the U-Net architecture from the original U-Net paper by Olaf Ronneberger *et al.*

![U-Net architecture from Ronneberger *et al.*](img/Unet.png)

## Intro
U-Nets are a relatively recent development in deep learning (Ronneberger *et al.*, 2015) and a very useful kind of convolutional neural network (CNN). They get their name from their structure. A series of convolution and pooling steps (called the analysis path) reduces the spatial resolution of the feature maps, and a series of up-convolution (sometimes called deconvolution) steps (called the synthesis path) expands it again. Drawn as a diagram, the network therefore has a U shape.
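
To make the "U" concrete, the sketch below tracks feature-map shapes through the network. It assumes the configuration used in the model definition later in this tutorial: 256x256 inputs, 16 filters at the first level doubling at every level, `'same'` padding, and four 2x2 pooling steps. `unet_feature_shapes` is a hypothetical helper for illustration, not part of Keras.

```python
def unet_feature_shapes(img_dims=256, base_filters=16, depth=4):
    """Return (height, width, channels) at each level of a U-Net.

    Assumes 'same'-padded convolutions, 2x2 max pooling and 2x2 stride-2
    up-convolutions, with the filter count doubling at every level.
    """
    if img_dims % (2 ** depth) != 0:
        # Each pooling step halves the size, so it must divide evenly
        raise ValueError(f"img_dims must be divisible by {2 ** depth}")
    analysis, size, filters = [], img_dims, base_filters
    for level in range(depth + 1):
        analysis.append((size, size, filters))
        if level < depth:
            size //= 2      # 2x2 max pooling halves height and width
            filters *= 2    # more filters at coarser resolution
    # The synthesis path mirrors the analysis path back to full resolution
    synthesis = list(reversed(analysis[:-1]))
    return analysis, synthesis


down, up = unet_feature_shapes()
for shape in down:
    print("analysis ", shape)
for shape in up:
    print("synthesis", shape)
```

With these settings the bottom of the U is a 16x16x256 feature map, which is also why the input size has to be divisible by 16.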

The analysis path extracts features from the input data and behaves much like the neural networks discussed previously: the input passes through a convolutional layer, an activation layer and a pooling layer, then the next convolutional layer, and so on until the spatial resolution has been reduced several times. There are two key differences, however:

1. Before each pooling step, the activation map is also sent along a shortcut (a skip connection) to the corresponding level of the synthesis path.
2. In the synthesis path, up-convolution maps the coarse, high-level features back to the resolution of that level. The up-convolved data is concatenated with the activation map from the same level of the analysis path, and the result is convolved again before the next up-convolution.

The skip connections provide high-resolution features to the up-convolution layers, which allows the U-Net to recover the important details of the input image while *filtering* out redundant ones. This is what makes U-Nets excel at image segmentation: they can outline key structures precisely, pixel by pixel, without first having to classify what those structures are. What's more, U-Nets can be combined into even larger networks or extended with other components, such as dense blocks, which can improve their performance in many situations.
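
The shortcut-and-concatenate mechanism can be seen in plain NumPy. This is a sketch for intuition only: it uses fixed nearest-neighbour upsampling where a U-Net would learn an up-convolution (`Conv2DTranspose` in Keras), and the array sizes and the helper names `max_pool_2x2` and `upsample_2x` are arbitrary choices for illustration.

```python
import numpy as np


def max_pool_2x2(x):
    """2x2 max pooling on an (H, W, C) array with even H and W."""
    h, w, c = x.shape
    return x.reshape(h // 2, 2, w // 2, 2, c).max(axis=(1, 3))


def upsample_2x(x):
    """Nearest-neighbour 2x upsampling (a stand-in for a learned up-convolution)."""
    return x.repeat(2, axis=0).repeat(2, axis=1)


rng = np.random.default_rng(0)
skip = rng.random((8, 8, 4))         # activation map saved before pooling
coarse = max_pool_2x2(skip)           # (4, 4, 4): half the resolution
restored = upsample_2x(coarse)        # (8, 8, 4): size restored, fine detail lost
merged = np.concatenate([restored, skip], axis=-1)  # (8, 8, 8): skip connection

print(coarse.shape, restored.shape, merged.shape)
print("detail lost by pool + upsample:", not np.allclose(restored, skip))
```

Concatenation stacks the two maps along the channel axis, so the following convolution sees both the upsampled coarse features and the original high-resolution activations.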

#### U-Net Paper
<div class="references">
<p>Olaf Ronneberger, Philipp Fischer and Thomas Brox. 2015. <em>U-Net: Convolutional Networks for Biomedical Image Segmentation</em>. <a href="https://arxiv.org/abs/1505.04597">arXiv: 1505.04597</a></p>
</div>

## Getting Started: Medical Image Segmentation with U-Net
**Welcome to the UNet tutorial for Chest X-Ray Image Segmentation!**
Medical image segmentation and medical image classification are very different tasks. Classification assigns each whole image to one of a number of predefined groups, or classes. For example, the Medical Image Classification tutorial classified chest X-ray images as showing either normal lungs or signs of pneumonia.<br/>
<br/>
Medical image segmentation, on the other hand, divides an image into regions by assigning a label to every pixel, producing a mask. Specifically, this tutorial covers semantic segmentation: we will train a fully convolutional neural network to produce, for each chest X-ray, a mask that covers the area of the image showing the lungs, segmenting the image into lung tissue and non-lung tissue.<br/>
<br/>
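
Where a classifier outputs one label per image, a segmentation network outputs one value per pixel. The NumPy sketch below shows how a per-pixel probability map (such as the sigmoid output used later) becomes a binary mask, and how that mask separates "lung" from "non-lung" pixels. The 4x4 image, the probabilities and the 0.5 threshold are made up for illustration; the helper names are hypothetical.

```python
import numpy as np


def probabilities_to_mask(prob_map, threshold=0.5):
    """Turn per-pixel probabilities (e.g. a sigmoid output) into a 0/1 mask."""
    return (prob_map >= threshold).astype(np.uint8)


def apply_mask(image, mask):
    """Keep pixels inside the mask (lung) and zero out everything else."""
    return image * mask


# Toy 4x4 "X-ray" and a made-up probability map, for illustration only
image = np.arange(16, dtype=np.float32).reshape(4, 4)
prob_map = np.array([[0.1, 0.9, 0.8, 0.2],
                     [0.1, 0.7, 0.6, 0.3],
                     [0.0, 0.4, 0.95, 0.1],
                     [0.0, 0.0, 0.2, 0.1]])

mask = probabilities_to_mask(prob_map)
print(mask)
print("fraction of pixels labelled lung:", mask.mean())
print(apply_mask(image, mask))
```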
The neural network used here is based on the work of *Ronneberger et al.* in their paper <a href="https://arxiv.org/abs/1505.04597">*U-Net: Convolutional Networks for Biomedical Image Segmentation*</a>. We should find that this network produces much better results than the simpler CNN used in the other tutorial.<br/>
<br/>
If you have not tried the CNN for Medical Image *Classification* tutorial, complete that one first, as it contains a more introductory approach to CNNs. This tutorial uses a Kaggle dataset of chest X-rays with lung masks to segment lung tissue. You can get the dataset <a href="https://www.kaggle.com/nikhilpandey360/chest-xray-masks-and-labels">here</a>.

**You will need to create a verified Kaggle account, download the data and change the source directory in this file.**
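
Before running the loading code, it can help to confirm that the source directory points at the right place. The sub-folder names checked below (`CXR_png`, `masks`, `test`) are the ones the loading code in this tutorial reads; `missing_subfolders` is a hypothetical helper, and the example path is the one used in this notebook, which you should replace with your own.

```python
from pathlib import Path

# Sub-folder names read by the loading code in this tutorial
EXPECTED_SUBFOLDERS = ("CXR_png", "masks", "test")


def missing_subfolders(root):
    """Return the expected dataset sub-folders that are not present under root."""
    root = Path(root)
    return [name for name in EXPECTED_SUBFOLDERS if not (root / name).is_dir()]


# Example (adjust the path to wherever you extracted the Kaggle download):
# print(missing_subfolders(r"C:\Datasets\LungSegmentation\Lung Segmentation"))
```

An empty list means all three folders were found.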

Let's get started! For brevity, since the prep work was covered in the previous tutorial, we will complete it below without explanation and jump straight to the model definition.

```python
import os
import random
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
from skimage.io import imread
from skimage.transform import resize
from tqdm import tqdm

img_dims = 256  # side length in pixels; must be divisible by 16 because the encoder pools four times
my_epochs = 20
batch_size = 32

# Change this to your copy of the dataset (raw string so the backslashes are kept as-is)
inPath = r"C:\Datasets\LungSegmentation\Lung Segmentation"

train_imgs = os.listdir(os.path.join(inPath, "CXR_png"))
test_imgs = os.listdir(os.path.join(inPath, "test"))

# Pair each X-ray with its mask. Masks are named either "<name>.png" or
# "<name>_mask.png"; images without a mask are skipped.
pairs = []
for file in train_imgs:
    for mask_name in (file, file[:-4] + "_mask.png"):
        if Path(inPath, "masks", mask_name).exists():
            pairs.append((file, mask_name))
            break

# validation_split takes the *last* samples, so shuffle to mix the image sources
random.Random(42).shuffle(pairs)

# Images are scaled to [0, 1] (assumes 8-bit PNGs); masks are 0/1 targets for binary cross-entropy
X_train = np.zeros((len(pairs), img_dims, img_dims, 3), dtype=np.float32)
Y_train = np.zeros((len(pairs), img_dims, img_dims, 1), dtype=np.uint8)

for n, (file, mask_name) in tqdm(enumerate(pairs), total=len(pairs)):
    t_img = imread(os.path.join(inPath, "CXR_png", file))
    t_img = resize(t_img, (img_dims, img_dims, 3), mode='constant', preserve_range=True)
    X_train[n] = t_img / 255.0
    m_mask = imread(os.path.join(inPath, "masks", mask_name))
    m_mask = resize(m_mask, (img_dims, img_dims, 1), mode='constant', preserve_range=True)
    Y_train[n] = m_mask > 0.5 * m_mask.max()  # binarise whatever the stored range (e.g. 0/255)

# The test array is sized by the number of *test* images
X_test = np.zeros((len(test_imgs), img_dims, img_dims, 3), dtype=np.float32)

for n, file in tqdm(enumerate(test_imgs), total=len(test_imgs)):
    test_img = imread(os.path.join(inPath, "test", file))
    test_img = resize(test_img, (img_dims, img_dims, 3), mode='constant', preserve_range=True)
    X_test[n] = test_img / 255.0
```
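
Two easy mistakes in loading code like this are leaving masks stored as 0/255 (binary cross-entropy expects targets between 0 and 1) and leaving images in the 0 to 255 range while the prediction code divides by 255. A quick check can catch both before training. `check_training_arrays` is a hypothetical helper, and the synthetic arrays below only stand in for `X_train` and `Y_train`.

```python
import numpy as np


def check_training_arrays(X, Y):
    """Basic sanity checks for segmentation inputs and targets.

    Returns a list of problems found (an empty list means all checks passed).
    Assumes images should be scaled to [0, 1] and masks should contain only 0 and 1.
    """
    problems = []
    if X.shape[:3] != Y.shape[:3]:
        problems.append(f"image/mask shape mismatch: {X.shape} vs {Y.shape}")
    if X.min() < 0 or X.max() > 1:
        problems.append(f"images not in [0, 1]: range {X.min()}..{X.max()}")
    if not np.isin(Y, [0, 1]).all():
        problems.append(f"masks are not binary: values {np.unique(Y)[:5]}")
    return problems


# Synthetic example: masks accidentally left as 0/255
X = np.random.default_rng(0).random((2, 8, 8, 3))
Y = np.zeros((2, 8, 8, 1), dtype=np.uint8)
Y[:, 2:6, 2:6] = 255
print(check_training_arrays(X, Y))  # reports the non-binary masks
```

In the notebook, `check_training_arrays(X_train, Y_train)` should return an empty list before you start training.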

## Model Definition
Using the U-Net diagram as a guide, we can define the model architecture with Keras layers, as in previous tutorials. The most important feature to consider here is the concatenation that occurs in the ascending (decoder) path, as mentioned in the intro.
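
Because concatenation doubles the channels entering the first convolution of each decoder level, it shows up directly in the parameter count. The sketch below redoes that arithmetic by hand, assuming the architecture defined in this section (3x3 convolutions, 2x2 up-convolutions, filters from 16 to 256, a 1x1 sigmoid output); the result should match the "Total params" line printed by `model.summary()`. The helper names are hypothetical; pooling and dropout layers have no parameters.

```python
def conv2d_params(in_ch, out_ch, k=3):
    """Weights + biases of a Conv2D layer with a k x k kernel."""
    return k * k * in_ch * out_ch + out_ch


def conv2d_transpose_params(in_ch, out_ch, k=2):
    """Weights + biases of a Conv2DTranspose layer with a k x k kernel."""
    return k * k * in_ch * out_ch + out_ch


def unet_param_count(in_ch=3, filters=(16, 32, 64, 128, 256)):
    total, ch, skips = 0, in_ch, []
    # Analysis path: two 3x3 convolutions per level
    for f in filters:
        total += conv2d_params(ch, f) + conv2d_params(f, f)
        skips.append(f)
        ch = f
    # Synthesis path: up-convolution, concatenate the skip, two 3x3 convolutions
    for f, skip_ch in zip(reversed(filters[:-1]), reversed(skips[:-1])):
        total += conv2d_transpose_params(ch, f)
        concat_ch = f + skip_ch  # channels after concatenation
        total += conv2d_params(concat_ch, f) + conv2d_params(f, f)
        ch = f
    # 1x1 output convolution with a single sigmoid channel
    total += conv2d_params(ch, 1, k=1)
    return total


print(unet_param_count())  # 1941105
```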

```python
# Start by defining the input layer
inputs = tf.keras.layers.Input((img_dims, img_dims, 3))

# Encoder level 1 (contraction): two 3x3 convolutions with dropout in between
e1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3,3), activation='relu', padding='same')(inputs)
e1 = tf.keras.layers.Dropout(rate=0.2)(e1)
e1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3,3), activation='relu', padding='same')(e1)

e2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(e1)

# Encoder level 2
e2 = tf.keras.layers.Conv2D(filters=32, kernel_size=(3,3), activation='relu', padding='same')(e2)
e2 = tf.keras.layers.Dropout(rate=0.2)(e2)
e2 = tf.keras.layers.Conv2D(filters=32, kernel_size=(3,3), activation='relu', padding='same')(e2)

e3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(e2)

# Encoder level 3
e3 = tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', padding='same')(e3)
e3 = tf.keras.layers.Dropout(rate=0.2)(e3)
e3 = tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', padding='same')(e3)

e4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(e3)

# Encoder level 4
e4 = tf.keras.layers.Conv2D(filters=128, kernel_size=(3,3), activation='relu', padding='same')(e4)
e4 = tf.keras.layers.Dropout(rate=0.2)(e4)
e4 = tf.keras.layers.Conv2D(filters=128, kernel_size=(3,3), activation='relu', padding='same')(e4)

e5 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(e4)

# Encoder level 5 (bottom of the U)
e5 = tf.keras.layers.Conv2D(filters=256, kernel_size=(3,3), activation='relu', padding='same')(e5)
e5 = tf.keras.layers.Dropout(rate=0.2)(e5)
e5 = tf.keras.layers.Conv2D(filters=256, kernel_size=(3,3), activation='relu', padding='same')(e5)

# Decoder level 4: up-convolve, then concatenate the skip connection from encoder level 4
d4 = tf.keras.layers.Conv2DTranspose(filters=128, kernel_size=(2,2), strides=(2,2), padding='same')(e5)
d4 = tf.keras.layers.concatenate([d4, e4])
d4 = tf.keras.layers.Conv2D(filters=128, kernel_size=(3,3), activation='relu', padding='same')(d4)
d4 = tf.keras.layers.Dropout(rate=0.2)(d4)
d4 = tf.keras.layers.Conv2D(filters=128, kernel_size=(3,3), activation='relu', padding='same')(d4)

# Decoder level 3
d3 = tf.keras.layers.Conv2DTranspose(filters=64, kernel_size=(2,2), strides=(2,2), padding='same')(d4)
d3 = tf.keras.layers.concatenate([d3, e3])
d3 = tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', padding='same')(d3)
d3 = tf.keras.layers.Dropout(rate=0.2)(d3)
d3 = tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', padding='same')(d3)

# Decoder level 2
d2 = tf.keras.layers.Conv2DTranspose(filters=32, kernel_size=(2,2), strides=(2,2), padding='same')(d3)
d2 = tf.keras.layers.concatenate([d2, e2])
d2 = tf.keras.layers.Conv2D(filters=32, kernel_size=(3,3), activation='relu', padding='same')(d2)
d2 = tf.keras.layers.Dropout(rate=0.2)(d2)
d2 = tf.keras.layers.Conv2D(filters=32, kernel_size=(3,3), activation='relu', padding='same')(d2)

# Decoder level 1
d1 = tf.keras.layers.Conv2DTranspose(filters=16, kernel_size=(2,2), strides=(2,2), padding='same')(d2)
d1 = tf.keras.layers.concatenate([d1, e1])
d1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3,3), activation='relu', padding='same')(d1)
d1 = tf.keras.layers.Dropout(rate=0.2)(d1)
d1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3,3), activation='relu', padding='same')(d1)

# Output layer: one sigmoid channel giving the per-pixel probability of lung tissue
outputs = tf.keras.layers.Conv2D(filters=1, kernel_size=(1,1), activation='sigmoid')(d1)

model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

# Callbacks
# Keep the weights from the epoch with the lowest val_loss
# (Keras 3 requires weight files to end in '.weights.h5')
checkpoint = ModelCheckpoint(filepath='xray_model.weights.h5', monitor='val_loss',
                             save_best_only=True, save_weights_only=True)
# Lower the learning rate when val_loss stops decreasing (a loss is minimised, so mode='min')
lr_reduce = ReduceLROnPlateau(monitor='val_loss', factor=0.3, patience=2, verbose=2, mode='min')
# Stop when val_loss has not improved by more than min_delta for 3 epochs;
# 0.1 is a very large step for a binary cross-entropy loss, so a smaller value is used
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=3, mode='min')
board = TensorBoard(log_dir='logs')

my_callbacks = [
    checkpoint,
    #lr_reduce,
    early_stop,
    board
]
```
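
`EarlyStopping` only counts an epoch as an improvement if `val_loss` drops by more than `min_delta`. The simplified re-implementation below is not the Keras source (it ignores options such as `baseline` and `restore_best_weights`), and the loss values are made up, but it shows why a large `min_delta` (such as 0.1 for a binary cross-entropy loss) can stop training while the loss is still falling.

```python
def early_stopping_epoch(val_losses, min_delta, patience):
    """Simplified re-implementation of EarlyStopping(mode='min') for intuition.

    An epoch counts as an improvement only if the loss beats the best value so
    far by more than min_delta. Returns the (0-based) epoch at which training
    would stop, or None if it runs through every epoch.
    """
    best, wait = float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, wait = loss, 0  # improvement: reset the patience counter
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [0.40, 0.25, 0.21, 0.18, 0.16, 0.14]  # made-up, steadily improving
print(early_stopping_epoch(losses, min_delta=0.1, patience=3))    # 4
print(early_stopping_epoch(losses, min_delta=0.001, patience=3))  # None
```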

## Training
With the model defined, we can now train it with the data.

```python
hist = model.fit(
    X_train,
    Y_train,
    validation_split=0.1,   # the last 10% of the arrays is held out for validation
    batch_size=batch_size,
    epochs=my_epochs,
    callbacks=my_callbacks
)
```
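
The Keras documentation notes that `validation_split` takes the validation data from the last samples of `x` and `y`, before shuffling, so the order in which files were loaded decides which images end up in the validation set. The sketch below mimics that slicing on NumPy arrays; `split_like_validation_split` is a hypothetical helper, and the exact rounding of the split point is an assumption.

```python
import math

import numpy as np


def split_like_validation_split(X, Y, validation_split):
    """Mimic how Keras' validation_split slices NumPy inputs.

    The validation set is the *last* fraction of the arrays, taken before any
    shuffling, so the order of samples determines what ends up in it.
    """
    split_at = int(math.floor(len(X) * (1.0 - validation_split)))
    return (X[:split_at], Y[:split_at]), (X[split_at:], Y[split_at:])


X = np.arange(10)  # stand-ins for 10 samples, in file order
Y = np.arange(10)
(train_x, _), (val_x, _) = split_like_validation_split(X, Y, 0.1)
print(train_x, val_x)  # validation holds only the final sample
```

If the file list is grouped by source (for example, sorted by file name), shuffling the samples once before calling `fit` gives a more representative validation set.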

## Training History
We can plot the accuracy and loss recorded on the training and validation data at each epoch.

```python
fig, ax = plt.subplots(1, 2, figsize=(10, 3))
ax = ax.ravel()

# One panel per metric, with training and validation curves overlaid
for i, met in enumerate(['accuracy', 'loss']):
    ax[i].plot(hist.history[met])
    ax[i].plot(hist.history['val_' + met])
    ax[i].set_title('Model {}'.format(met))
    ax[i].set_xlabel('epochs')
    ax[i].set_ylabel(met)
    ax[i].legend(['train', 'val'])

plt.tight_layout()
plt.show()
```
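
The curves show the trend; to read off the single best epoch (for example, the one whose weights `ModelCheckpoint` with `save_best_only=True` keeps, which monitors `val_loss` by default), a small helper over the history dictionary is enough. `best_epoch` is a hypothetical helper, and the history values below are made up; in the notebook you would pass `hist.history`.

```python
def best_epoch(history, metric="val_loss", mode="min"):
    """Return (epoch_number, value) of the best entry in a Keras-style history dict.

    Epochs are numbered from 1, as in the Keras training log.
    """
    values = history[metric]
    pick = min if mode == "min" else max
    best_value = pick(values)
    return values.index(best_value) + 1, best_value


# Made-up history for illustration; in the notebook, pass hist.history instead
history = {"loss": [0.50, 0.30, 0.22, 0.20],
           "val_loss": [0.45, 0.28, 0.26, 0.27],
           "accuracy": [0.80, 0.90, 0.93, 0.94],
           "val_accuracy": [0.82, 0.91, 0.92, 0.92]}
print(best_epoch(history))                              # (3, 0.26)
print(best_epoch(history, "val_accuracy", mode="max"))  # (3, 0.92)
```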

## Making a Prediction

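```python
# Path to a chest X-ray to segment (adjust to where the image is on your machine)
img_path = 'CHNCXR_0025_0.png'

# Load the image as RGB at the model's input size
img = tf.keras.utils.load_img(img_path, target_size=(img_dims, img_dims))
pred = tf.keras.utils.img_to_array(img)
pred = np.expand_dims(pred, axis=0)      # add a batch dimension: (1, H, W, 3)
pred = pred.astype('float32') / 255      # scale pixel values to [0, 1]

prediction = model.predict(pred)         # (1, H, W, 1) per-pixel lung probabilities

# Drop the batch and channel axes so matplotlib can display the mask
plt.imshow(prediction[0, :, :, 0], cmap='gray')
plt.show()
```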
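
The network outputs probabilities, so a hard mask needs a threshold, and pixel accuracy can be inflated by the many easy background pixels. Overlap scores such as the Dice coefficient and intersection over union (IoU) are common alternatives for judging a predicted mask against a reference mask. The sketch below assumes a 0.5 threshold and uses made-up 2x2 arrays; `dice_and_iou` is a hypothetical helper.

```python
import numpy as np


def dice_and_iou(pred_mask, true_mask):
    """Overlap scores between two binary masks of the same shape.

    Dice = 2|A & B| / (|A| + |B|),  IoU = |A & B| / |A | B|.
    Both are defined as 1.0 when the two masks are empty.
    """
    a = np.asarray(pred_mask, dtype=bool)
    b = np.asarray(true_mask, dtype=bool)
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0, 1.0
    return 2.0 * inter / total, inter / union


# Toy example: threshold a sigmoid-style output at 0.5, then compare to a reference mask
probs = np.array([[0.9, 0.8], [0.3, 0.1]])
truth = np.array([[1, 0], [0, 0]])
dice, iou = dice_and_iou(probs > 0.5, truth)
print(dice, iou)
```

In the notebook, `prediction[0, :, :, 0] > 0.5` gives the predicted mask, which could be compared in the same way with a resized ground-truth mask for the same image, where one exists.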

## Saving the Model
Finally, we should save the model for future use.

```python
# model.save('path/to/file/filename.keras')
# Keras 3 requires a '.keras' (or '.h5') extension when saving a whole model
model.save('UNet_Segmentation_Model.keras')
```