<a href="https://colab.research.google.com/github/ICRAR/PHYS5511/blob/master/2019/week09/lung_segmentation_from_chest_x_ray_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Lung segmentation from Chest X-Ray dataset

**About the data**:
- The [dataset](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4256233/) is made up of images and segmentation masks from two different sources (the Montgomery and Shenzhen datasets).
- The mask naming convention is slightly inconsistent: some mask files have exactly the same name as their image, while others carry an extra `_mask` suffix.
- Some images don't have corresponding masks.
- Images from the Shenzhen dataset have visibly smaller lungs (relative to the image size) than those from the Montgomery dataset.
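
The pairing rules above can be sketched as a small helper. This is an illustration only, not the loading code used later in this notebook (which handles the two conventions inside `get_data`); `pair_images_and_masks` and the example file names are hypothetical, assuming masks are named either `<name>.png` or `<name>_mask.png`.

```python
def pair_images_and_masks(image_names, mask_names):
    """Match each mask file to its image file.

    Handles both naming conventions: 'X.png' -> 'X.png' and
    'X_mask.png' -> 'X.png'. Returns (pairs, images_without_mask).
    """
    available = set(image_names)
    pairs = []
    for mask_name in sorted(mask_names):
        # Strip the optional "_mask" suffix to recover the image name.
        image_name = mask_name.replace("_mask.png", ".png")
        if image_name in available:
            pairs.append((image_name, mask_name))
    matched = {image for image, _ in pairs}
    images_without_mask = sorted(available - matched)
    return pairs, images_without_mask


# Hypothetical file names following the two conventions.
example_images = ["MCUCXR_0001_0.png", "CHNCXR_0001_0.png", "CHNCXR_0002_0.png"]
example_masks = ["MCUCXR_0001_0.png", "CHNCXR_0001_0_mask.png"]
pairs, missing = pair_images_and_masks(example_images, example_masks)
print(pairs)
print(missing)  # ['CHNCXR_0002_0.png']
```

Going from masks to images (rather than the other way round) means images without a mask are simply left out instead of causing a failed lookup.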



This Notebook was adapted from this [Kaggle kernel](https://www.kaggle.com/nikhilpandey360/lung-segmentation-from-chest-x-ray-dataset/notebook).

In [0]:
from google.colab import drive
drive.mount('/content/drive')

#Setup data

##Download data first

In [0]:
%cd /content/drive/My\ Drive/PHYS5512/data
!mkdir -p lung_segmentation

In [0]:
%cd lung_segmentation

In [0]:
!mkdir -p /root/.kaggle
!cp /content/drive/My\ Drive/PHYS5512/kaggle.json /root/.kaggle/
!chmod 600 /root/.kaggle/kaggle.json
!kaggle datasets download -d nikhilpandey360/chest-xray-masks-and-labels

#Go to the directory

In [0]:
%cd /content/drive/My\ Drive/PHYS5512/data/lung_segmentation

#Data pre-processing

In [0]:
import numpy as np
from zipfile import ZipFile
import pandas as pd
from tqdm import tqdm
import os
import os.path as osp
from cv2 import imread, createCLAHE 
import cv2
from glob import glob
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Image



In [0]:
zipfname = 'chest-xray-masks-and-labels.zip'
archive = ZipFile(zipfname)
all_files = archive.namelist()
print(len(all_files))
all_files[0:10]


In [0]:
image_path_prefix = 'Lung Segmentation/CXR_png'
mask_path_prefix = 'Lung Segmentation/masks'


The related dataset can be inspected separately [here](https://www.kaggle.com/kmader/pulmonary-chest-xray-abnormalities/home).

In [0]:
# There are 704 masks but 800 images, so we build a one-to-one correspondence
# from masks to images rather than the usual images-to-masks direction.
# [1:] skips the first matching entry, assumed to be the folder entry itself.
images = [x.split('/')[-1] for x in all_files if x.startswith(image_path_prefix)][1:]
masks = [x.split('/')[-1] for x in all_files if x.startswith(mask_path_prefix)][1:]
print(images[0:3])
print(masks[0:3])


In [0]:
mask = [fName.split(".png")[0] for fName in masks]
image_file_name = [fName.split("_mask")[0] for fName in mask]

In [0]:
image_file_name[0:3]

In [0]:
mask[0:3]

In [0]:
check = [i for i in mask if "mask" in i]
print("Total masks with a modified name:", len(check))

In [0]:
check[0:3]

In [0]:
testing_files = set(images) & set(masks)

In [0]:
list(testing_files)[0:3]

Earlier I was going to train on the Shenzhen dataset and predict on the Montgomery dataset. However, the two sets turned out to differ in nature: the images from the Shenzhen dataset have a **smaller** lung-to-image ratio than those from the Montgomery dataset.

Thus, I load the two datasets separately and combine them once the disparity is understood.
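
The disparity can be quantified rather than judged by eye: the fraction of pixels labelled as lung in each mask is a direct proxy for the lung-to-image ratio. A minimal sketch, assuming masks stored as 0-255 arrays where lung pixels are above 127 (the threshold used later when the masks are binarised); `lung_fraction` is a helper introduced only for this illustration.

```python
import numpy as np


def lung_fraction(masks, threshold=127):
    """Fraction of pixels labelled as lung in each mask.

    `masks` has shape (N, H, W) or (N, H, W, 1) with 0-255 values;
    pixels above `threshold` count as lung.
    """
    masks = np.asarray(masks)
    foreground = masks > threshold
    # Average over every axis except the first (the sample axis).
    return foreground.reshape(len(masks), -1).mean(axis=1)


# Two synthetic 4x4 masks: one with 4 lung pixels, one with 8.
demo = np.zeros((2, 4, 4), dtype=np.uint8)
demo[0, 1:3, 1:3] = 255
demo[1, :, 0:2] = 255
print(lung_fraction(demo))  # fractions 0.25 and 0.5
```

Once the data are loaded, comparing `lung_fraction(y_train).mean()` (Shenzhen) with `lung_fraction(y_test).mean()` (Montgomery) should expose the difference in a single number per set; no measured values are reported here.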

In [0]:
# Montgomery masks share their image's file name; Shenzhen masks carry a "_mask" suffix.
# Sorting fixes the loading order: iterating over a set of strings can give a different
# order in each Python session, which would make saved arrays and indices irreproducible.
testing_files = sorted(set(images) & set(masks))
training_files = sorted(check)

def read_resized(path, X_shape):
    """Decode a PNG stored in the zip archive, resize it to X_shape x X_shape
    and keep the first channel (IMREAD_COLOR always returns 3 channels)."""
    buf = np.frombuffer(archive.read(path), np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    return cv2.resize(img, (X_shape, X_shape))[:, :, 0]

def get_data(X_shape, flag="test"):
    """Return lists of resized images and masks.

    flag="test"  -> images whose mask has the same file name (Montgomery)
    flag="train" -> images whose mask name ends in "_mask" (Shenzhen)
    """
    im_array = []
    mask_array = []

    if flag == "test":
        for name in tqdm(testing_files):
            im_array.append(read_resized(image_path_prefix + "/" + name, X_shape))
            mask_array.append(read_resized(mask_path_prefix + "/" + name, X_shape))
    elif flag == "train":
        for name in tqdm(training_files):
            image_name = name.split("_mask")[0] + ".png"
            im_array.append(read_resized(image_path_prefix + "/" + image_name, X_shape))
            mask_array.append(read_resized(mask_path_prefix + "/" + name + ".png", X_shape))
    else:
        raise ValueError("flag must be 'test' or 'train', got %r" % flag)

    return im_array, mask_array

In [0]:
# Sanity check: display image/mask pairs side by side

def plot_mask(X, y, n=6):
    """Show the first n (image | mask) pairs in one figure, two pairs per row."""
    rows = (n + 1) // 2
    plt.figure(figsize=(25, 10 * rows))
    for i in range(n):
        # place each image next to its mask in a single panel
        combined = np.hstack((X[i], y[i]))
        plt.subplot(rows, 2, i + 1)
        plt.imshow(combined)
        plt.axis('off')
    plt.show()

In [0]:
# Load training and testing data
dim = 256 * 2
X_train,y_train = get_data(dim, flag="train")
X_test, y_test = get_data(dim)

##Perform Sanity Check

It is prudent to sanity-check the correspondence between images and masks. This becomes routine after a while, but it is crucial for catching mistakes made while loading the data.
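
The visual check can be complemented by a few automated ones. The helper below is a hypothetical addition, not part of the original notebook; it assumes masks are 0-255 arrays in which lung pixels are above 127 (the same threshold used later when binarising), and flags mismatched counts or shapes and masks that are entirely empty or entirely full.

```python
import numpy as np


def check_pairs(X, y, threshold=127):
    """Basic automated checks on sequences of images X and masks y.

    Raises ValueError on the first problem found; returns the number of pairs.
    """
    if len(X) != len(y):
        raise ValueError("%d images but %d masks" % (len(X), len(y)))
    for i, (img, msk) in enumerate(zip(X, y)):
        img, msk = np.asarray(img), np.asarray(msk)
        if img.shape != msk.shape:
            raise ValueError("pair %d: image %s vs mask %s" % (i, img.shape, msk.shape))
        # An all-black or all-white mask usually means a loading or pairing error.
        fg = (msk > threshold).mean()
        if not 0.0 < fg < 1.0:
            raise ValueError("pair %d: mask foreground fraction is %.2f" % (i, fg))
    return len(X)


# Synthetic example: two 4x4 images, each with a 2x2 "lung" in its mask.
demo_images = [np.zeros((4, 4), np.uint8) for _ in range(2)]
demo_masks = [np.zeros((4, 4), np.uint8) for _ in range(2)]
for m in demo_masks:
    m[1:3, 1:3] = 255
print(check_pairs(demo_images, demo_masks))  # 2
```

After loading, `check_pairs(X_train, y_train)` and `check_pairs(X_test, y_test)` would run the same checks on the real data. These checks cannot tell whether a mask belongs to the *right* image, so the plots remain necessary.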

In [0]:
print("training set")
plot_mask(X_train, y_train)
print("testing set")
plot_mask(X_test, y_test)

Both sets look correct. Let's combine them and use them as a single dataset.

In [0]:
X_train = np.array(X_train).reshape(len(X_train),dim,dim,1)
y_train = np.array(y_train).reshape(len(y_train),dim,dim,1)
X_test = np.array(X_test).reshape(len(X_test),dim,dim,1)
y_test = np.array(y_test).reshape(len(y_test),dim,dim,1)
assert X_train.shape == y_train.shape
assert X_test.shape == y_test.shape

In [0]:
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

In [0]:
np.save('X_train.npy', X_train)
np.save('X_test.npy', X_test)
np.save('y_train.npy', y_train)
np.save('y_test.npy', y_test)

#Start from here after pre-processing

In [0]:
X_train = np.load('X_train.npy')
X_test = np.load('X_test.npy')
y_train = np.load('y_train.npy')
y_test = np.load('y_test.npy')

In [0]:
images = np.concatenate((X_train, X_test), axis=0)
mask = np.concatenate((y_train, y_test), axis=0)

In [0]:
# Show a few (image, mask) pairs from the combined dataset
pick_n = 4
#ind_check = np.random.choice(len(images), pick_n)
ind_check = np.array([622, 170, 611, 309])
plt.figure(figsize=(pick_n * 3, pick_n * 6))
for i, idx in enumerate(ind_check):
    plt.subplot(pick_n, 2, i * 2 + 1)
    plt.imshow(images[idx, :, :, 0])
    plt.axis('off')
    plt.subplot(pick_n, 2, i * 2 + 2)
    plt.imshow(mask[idx, :, :, 0])
    plt.axis('off')
plt.tight_layout()

#Define the network and callbacks

We are going to use the widely cited [U-Net model](https://arxiv.org/abs/1505.04597) to solve the image segmentation problem. Please read this [excellent overview](https://www.jeremyjordan.me/semantic-segmentation) to understand some basic concepts of semantic image segmentation.

First, let us download and visualise the U-net architecture.

In [0]:
!wget --no-check-certificate https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png

In [0]:
Image(filename='u-net-architecture.png', width=1050)

##Upsampling via transposed convolution
Before we build the full U-Net model, we need to understand a bit more about how **learnable upsampling** works in U-Net (i.e. the **green up-conv 2x2** arrows in the architecture diagram above). Transposed convolution is a common choice for upsampling: it learns its kernel weights through backpropagation to map a low-resolution feature map to a higher-resolution one.
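
A minimal NumPy sketch of what a transposed convolution computes (single channel, no bias, 'valid' padding). `conv_transpose_2d` is a helper written for this illustration, not a Keras function: each input value scatters a scaled copy of the kernel into the output, the copies are placed `stride` pixels apart, and overlapping copies are summed. The sketch assumes the kernel is at least as large as the stride, as in U-Net's 2x2, stride-2 up-convolutions; when the kernel is smaller than the stride, Keras pads the output to `input_size * stride` instead.

```python
import numpy as np


def conv_transpose_2d(x, kernel, stride):
    """Single-channel transposed convolution, 'valid' padding, no bias.

    Assumes kernel size >= stride; output size is (n - 1) * stride + k.
    """
    h, w = x.shape
    kh, kw = kernel.shape
    out = np.zeros(((h - 1) * stride + kh, (w - 1) * stride + kw))
    for i in range(h):
        for j in range(w):
            # scatter a copy of the kernel, scaled by x[i, j], and sum overlaps
            out[i * stride:i * stride + kh, j * stride:j * stride + kw] += x[i, j] * kernel
    return out


x = np.array([[1.0, 2.0],
              [3.0, 4.0]])
k = np.ones((2, 2))
print(conv_transpose_2d(x, k, stride=2))
```

With a 2x2 kernel of ones and stride 2, every input value is simply copied into a 2x2 block, giving `[[1,1,2,2],[1,1,2,2],[3,3,4,4],[3,3,4,4]]`. In a trained network the kernel values are learned instead; with a 3x3 kernel and stride 2, neighbouring copies overlap and are summed.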

In [0]:
# example of using the transpose convolutional layer
# adapted from 
# https://machinelearningmastery.com/upsampling-and-transpose-convolution-layers-for-generative-adversarial-networks/
from numpy import asarray
from keras.models import Sequential
from keras.layers import Conv2DTranspose
# define input data
X = asarray([[1, 2],
             [3, 4]])
# show input data for context
print(X)
# reshape input data into one sample with one channel: (batch, height, width, channels)
X = X.reshape((1, 2, 2, 1))
# define model: one filter with a 1x1 kernel and stride 2, which doubles height and width
model = Sequential()
model.add(Conv2DTranspose(1, (1,1), strides=(2,2), input_shape=(2, 2, 1)))
# summarize the model
model.summary()
# fix the weights (kernel = 1, bias = 0) so the layer only spreads the input values out;
# comment out the next two lines to see the effect of the random initial weights instead
weights = [asarray([[[[1]]]]), asarray([0])]
model.set_weights(weights)
# make a prediction with the model
yhat = model.predict(X)
print(yhat.shape)
# drop the batch and channel axes to make printing easier
yhat = yhat.reshape((4, 4))
# summarize output
print(yhat)

So why do we need such "upsampling" at all? The section ***iii) Need for up sampling*** in [this article](https://towardsdatascience.com/understanding-semantic-segmentation-with-unet-6be4f42d4b47) gives an intuitive answer:

> The output of semantic segmentation is not just a class label or some bounding box parameters. In fact the output is a complete high resolution image in which all the pixels are classified.
>
> Thus if we use a regular convolutional network with pooling layers and dense layers, we will lose the “WHERE” information and only retain the “WHAT” information which is not what we want. In case of segmentation we need both “WHAT” as well as “WHERE” information.
>
> Hence there is a need to up sample the image, i.e. convert a low resolution image to a high resolution image to recover the “WHERE” information.
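
A tiny NumPy illustration of the lost "WHERE" information: after 2x2 max pooling, two inputs whose bright pixel sits at different positions inside the same pooling window become indistinguishable. `max_pool_2x2` is a helper written only for this illustration (it assumes a 2D array with even height and width).

```python
import numpy as np


def max_pool_2x2(x):
    """2x2 max pooling with stride 2 on a 2D array with even height and width."""
    h, w = x.shape
    # split into (h/2, 2, w/2, 2) blocks and take the max inside each 2x2 block
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


a = np.zeros((4, 4))
b = np.zeros((4, 4))
a[0, 0] = 1.0  # feature in the top-left corner of the first 2x2 block
b[1, 1] = 1.0  # same feature, shifted to the bottom-right corner of that block

print(max_pool_2x2(a))
print(max_pool_2x2(b))
print(np.array_equal(max_pool_2x2(a), max_pool_2x2(b)))  # True
```

The pooled maps still say *what* was detected, but the exact position inside each window is gone; the decoder and its skip connections exist to recover it.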

So here is the complete U-Net model. Compared to the Keras models developed in previous tutorials, this U-Net model differs in two ways:

1. it uses the [Keras functional API](https://keras.io/getting-started/functional-api-guide/) to define the model
2. it defines a new loss function, the Dice coefficient loss
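
The Dice coefficient measures the overlap between a predicted mask and the true mask. The NumPy function below mirrors the Keras `dice_coef` defined in the next cell, including its `+ 1` smoothing term, so the value can be checked by hand; `dice_coef_np` is a name introduced only for this illustration.

```python
import numpy as np


def dice_coef_np(y_true, y_pred, smooth=1.0):
    """Soft Dice coefficient: (2 * sum(A * B) + smooth) / (sum(A) + sum(B) + smooth)."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth)


truth = np.array([[1, 1, 0, 0],
                  [1, 1, 0, 0]])
pred = np.array([[1, 1, 1, 0],
                 [0, 0, 0, 0]])
print(dice_coef_np(truth, pred))      # (2*2 + 1) / (4 + 3 + 1) = 0.625
print(1 - dice_coef_np(truth, pred))  # the corresponding Dice loss, 0.375
```

Because `y_pred` enters as a product rather than through a hard threshold, the same formula also works on the sigmoid probabilities the network outputs, which is what makes it usable as a differentiable loss.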

In [0]:
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
from keras.optimizers import Adam
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, LearningRateScheduler


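def dice_coef(y_true, y_pred):
    # Soft Dice coefficient 2|A ∩ B| / (|A| + |B|), with +1 smoothing so that
    # two empty masks give 1 instead of 0/0.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + 1) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1)

def dice_coef_loss(y_true, y_pred):
    # Minimising 1 - Dice maximises the overlap between prediction and mask.
    return 1 - dice_coef(y_true, y_pred)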

def unet(input_size=(256, 256, 1)):
    inputs = Input(input_size)
    
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

    return Model(inputs=[inputs], outputs=[conv10])

Questions:
1. Why use Sigmoid as the final activation function?
2. What is the rationale behind Dice Coefficient?
3. Why do we have to use the functional API?

To get a more detailed understanding of U-Net, please see the ***Points to note*** section in [this article](https://towardsdatascience.com/understanding-semantic-segmentation-with-unet-6be4f42d4b47). In particular, the skip connections, implemented with `keras.layers.concatenate`, are important for recovering location information:

> To get more precise locations, at every step of the decoder we use skip connections by concatenating the output of the transposed convolution layers with the feature maps from the Encoder at the same level:

*   u6 = u6' + c4, u6' = upsample(c5)
*   u7 = u7' + c3, u7' = upsample(c6)
*   u8 = u8' + c2, u8' = upsample(c7)
*   u9 = u9' + c1, u9' = upsample(c8)

Here "+" denotes concatenation along the channel axis, not element-wise addition; in `unet` above, u6 corresponds to `up6`, c4 to `conv4`, and so on.

After every concatenation we again apply two consecutive regular convolutions so that the model can learn to assemble a more precise output.
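
The shape difference between concatenating and adding the two tensors can be checked with a short NumPy sketch. It uses the shapes `up6` would see for a 512x512 input (the transposed convolution turns `conv5`, 32x32x512, into 64x64x256, and `conv4` is also 64x64x256); zero and one arrays stand in for real feature maps.

```python
import numpy as np

# Stand-ins for real feature maps, shaped (batch, height, width, channels).
upsampled = np.zeros((1, 64, 64, 256))  # u6' = upsample(c5)
skip = np.ones((1, 64, 64, 256))        # c4 from the encoder

concatenated = np.concatenate([upsampled, skip], axis=3)  # what `unet` does
added = upsampled + skip                                   # element-wise addition

print(concatenated.shape)  # (1, 64, 64, 512)
print(added.shape)         # (1, 64, 64, 256)
```

Concatenation keeps the encoder's feature map intact in its own channels and lets the following convolutions decide how to combine it with the upsampled features; element-wise addition (as used in residual connections) would mix them before any learning takes place.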

In [0]:
!wget -O u-net-upsampling.png  https://www.jeremyjordan.me/content/images/2018/05/Screen-Shot-2018-05-20-at-12.26.53-PM.png

In [0]:
Image(filename='u-net-upsampling.png', width=1000)

## Compile and train the U-Net model

In [0]:
model = unet(input_size=(512, 512, 1))
model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss,
              metrics=[dice_coef, 'binary_accuracy'])
model.summary()

In [0]:
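from keras.utils import plot_model
# plot_model writes the diagram to model.png (requires pydot and graphviz); display it explicitly
plot_model(model, to_file='model.png', show_layer_names=False, show_shapes=True)
Image(filename='model.png')

The above diagram shows a "rotated" U-shaped network. Please pay attention to the shape information (input and output) of each layer.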
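
The shapes in the diagram can also be traced by hand. The pure-Python sketch below (`unet_shapes` is a hypothetical helper, not a Keras utility) assumes what `unet` above does: 'same' padding on every 3x3 convolution, 2x2 pooling that halves height and width, stride-2 transposed convolutions that double them, and concatenation with the encoder feature map of the same size. Layers are labelled with the variable names used in `unet`, and the batch dimension (shown as `None` in the diagram) is omitted.

```python
def unet_shapes(size=512, base=32, depth=4):
    """Trace (height, width, channels) through the U-Net defined in `unet`."""
    shapes = []
    s, c = size, base
    for level in range(depth):
        shapes.append(("conv%d" % (level + 1), (s, s, c)))
        s, c = s // 2, c * 2  # pooling halves the size; next level doubles channels
    shapes.append(("conv%d" % (depth + 1), (s, s, c)))  # bottleneck
    for level in range(depth):
        s, c = s * 2, c // 2  # transposed conv doubles the size, halves channels
        # upsampled map (c channels) concatenated with the skip (c channels)
        shapes.append(("up%d" % (depth + 2 + level), (s, s, 2 * c)))
        shapes.append(("conv%d" % (depth + 2 + level), (s, s, c)))
    shapes.append(("conv%d" % (2 * depth + 2), (s, s, 1)))  # 1x1 sigmoid output
    return shapes


for name, shape in unet_shapes(512):
    print(name, shape)
```

For a 512x512 input the bottleneck `conv5` is 32x32x512, each `upN` layer has twice the channels of the `convN` that follows it, and the output `conv10` is back at 512x512 with one channel.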

## Callbacks, Early Stopping and Reduced LR


In [0]:
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau
weight_path="{}_weights.best.hdf5".format('cxr_reg')

checkpoint = ModelCheckpoint(weight_path, monitor='val_loss', verbose=1, 
                             save_best_only=True, mode='min', save_weights_only = True)

# min_delta replaces the deprecated `epsilon` argument of older Keras versions
reduceLROnPlat = ReduceLROnPlateau(monitor='val_loss', factor=0.5, 
                                   patience=3, 
                                   verbose=1, mode='min', min_delta=0.0001, cooldown=2, min_lr=1e-6)
early = EarlyStopping(monitor="val_loss", 
                      mode="min", 
                      patience=15) # probably needs to be more patient, but kaggle time is limited
callbacks_list = [checkpoint, early, reduceLROnPlat]
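
`ReduceLROnPlateau` halves the learning rate (`factor=0.5`) whenever the validation loss has not improved for `patience=3` epochs, then pauses for `cooldown=2` epochs, and never goes below `min_lr=1e-6`. The sketch below (`plateau_schedule` is a hypothetical helper, not part of Keras) lists the rates that successive reductions would visit, starting from the `2e-4` used when the model is recompiled for training. Whether that many reductions actually happen depends on how training goes.

```python
def plateau_schedule(initial_lr, factor=0.5, min_lr=1e-6, n_reductions=10):
    """Learning rates visited if a plateau callback fires n_reductions times.

    Each reduction multiplies the rate by `factor`, never going below `min_lr`.
    """
    lrs = [initial_lr]
    for _ in range(n_reductions):
        lrs.append(max(lrs[-1] * factor, min_lr))
    return lrs


for step, lr in enumerate(plateau_schedule(2e-4)):
    print(step, "%.4g" % lr)
```

Seven halvings take 2e-4 down to about 1.56e-6; the eighth would drop below 1e-6, so the rate is clamped at `min_lr` from then on.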

## Train the model

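I initially used a 60-40 train-test split and got a loss of -0.97 (presumably with the negative Dice coefficient as the loss, since the `1 - dice_coef` loss defined above cannot be negative). However, a better approach is an 80-10-10 train-validation-test split. Below I do roughly the latter.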
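
Because the validation set is carved out of what remains after the test split, the proportions are not exactly 80-10-10. Assuming all 704 masked images are used and scikit-learn's documented rounding for a float `test_size` (the test share is rounded up, the rest goes to training), the arithmetic of the two `train_test_split` calls is sketched below; `split_sizes` is a hypothetical helper.

```python
import math


def split_sizes(n_samples, test_size):
    """(n_train, n_test) for a float test_size, rounding the test share up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


n_total = 704                                   # number of masks in the dataset
n_trainval, n_test = split_sizes(n_total, 0.1)  # first split: hold out the test set
n_train, n_val = split_sizes(n_trainval, 0.1)   # second split: validation set
print(n_train, n_val, n_test)                   # 569 64 71
```

That is roughly 81/9/10 per cent of the data.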

In [0]:
from IPython.display import clear_output
from keras.optimizers import Adam 
from sklearn.model_selection import train_test_split

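model.compile(optimizer=Adam(lr=2e-4),
              loss=[dice_coef_loss],
              metrics=[dice_coef, 'binary_accuracy'])

# Scale pixel values from [0, 255] to roughly [-1, 1] and binarise the masks
# (resizing with interpolation creates intermediate grey values at mask edges).
train_vol, test_vol, train_seg, test_seg = train_test_split((images - 127.0) / 127.0, 
                                                            (mask > 127).astype(np.float32), 
                                                            test_size=0.1, 
                                                            random_state=2018)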

train_vol, val_vol, train_seg, val_seg = train_test_split(train_vol, 
                                                          train_seg, 
                                                          test_size=0.1, 
                                                          random_state=2018)

loss_history = model.fit(train_vol, train_seg,
                         batch_size=16, epochs=5,
                         validation_data=(val_vol, val_seg),
                         callbacks=callbacks_list)


clear_output()

#Plot the metrics and evaluate

In [0]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize = (15, 5))
ax1.plot(loss_history.history['loss'], '-', label='Loss')
ax1.plot(loss_history.history['val_loss'], '-', label='Validation Loss')
ax1.legend()

ax2.plot(100 * np.array(loss_history.history['binary_accuracy']), '-', 
         label = 'Train Binary Accuracy')
ax2.plot(100 * np.array(loss_history.history['val_binary_accuracy']), '-',
         label = 'Validation Binary Accuracy')
ax2.legend()

ax3.plot(100 * np.array(loss_history.history['dice_coef']), '-', 
         label = 'Train Dice Coefficient')
ax3.plot(100 * np.array(loss_history.history['val_dice_coef']), '-',
         label = 'Validation Dice Coefficient')
ax3.legend(loc='best')

**Question** - Why is the binary accuracy always higher than the Dice coefficient?

#Test the model

In [0]:
# Pick three distinct random test images (index 0 included) to display
pred_candidates = np.random.choice(test_vol.shape[0], 3, replace=False)
preds = model.predict(test_vol)

plt.figure(figsize=(20, 10))

for row, idx in enumerate(pred_candidates):
    plt.subplot(3, 3, 3 * row + 1)
    plt.imshow(np.squeeze(test_vol[idx]))
    plt.xlabel("Base Image")

    plt.subplot(3, 3, 3 * row + 2)
    plt.imshow(np.squeeze(test_seg[idx]))
    plt.xlabel("Mask")

    plt.subplot(3, 3, 3 * row + 3)
    plt.imshow(np.squeeze(preds[idx]))
    plt.xlabel("Prediction")

plt.show()