# Setting up a Vanilla U-Net for Chest X-Ray Lung Segmentation

## Part 1: Importing required libraries and data preparation

In [None]:
# Extract the MontgomerySet archive (CXR_png images, ManualMask lung masks, ClinicalReadings).
!unzip MontgomerySet.zip

Archive:  MontgomerySet.zip
  inflating: MontgomerySet/.DS_Store  
   creating: MontgomerySet/ClinicalReadings/
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0001_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0002_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0003_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0004_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0005_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0006_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0008_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0011_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0013_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0015_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0016_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0017_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0019_0.txt  
  inflating: MontgomerySet/ClinicalReadings/MCUCXR_0020_

In [None]:
#Imports
import os
import numpy as np
import cv2
from glob import glob
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Recall, Precision
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model

In [None]:
def load_data(path, split=0.1):
    """
    Collect the image/mask file paths and split them into train, validation
    and test sets.

    Parameters
      ----------
      path : str
          Root directory of the dataset. Images are read from
          ``<path>/CXR_png/*.png`` and masks from
          ``<path>/ManualMask/leftMask/*.png`` and
          ``<path>/ManualMask/rightMask/*.png``.

      split : float, default 0.1
          Fraction of the whole dataset used for EACH of the validation and
          test sets: ``int(len(images) * split)`` samples go to validation,
          the same number is then taken from the remainder for test, and the
          rest is training data (roughly 80/10/10 for ``split=0.1``).

    Returns
      ----------
      data : three tuples ``(train, val, test)``
          Each tuple is ``(image_paths, left_mask_paths, right_mask_paths)``;
          the i-th entries of the three lists belong to the same sample.

    Raises
      ----------
      ValueError
          Raised by ``train_test_split`` if the image and mask lists do not
          have the same length, instead of silently pairing the wrong files.
    """
    # Sorting is what lines the i-th image up with the i-th left/right mask;
    # this assumes the files share base names across the three directories.
    images = sorted(glob(os.path.join(path, "CXR_png", "*.png")))
    masks_l = sorted(glob(os.path.join(path, "ManualMask", "leftMask", "*.png")))
    masks_r = sorted(glob(os.path.join(path, "ManualMask", "rightMask", "*.png")))

    # Number of samples held out for validation (and again for test).
    split_size = int(len(images) * split)

    # Splitting all three lists in ONE call applies the same permutation to
    # each of them, so image/mask triplets stay aligned by construction, and
    # sklearn rejects lists of inconsistent length.
    (train_x, val_x,
     train_y_l, val_y_l,
     train_y_r, val_y_r) = train_test_split(images, masks_l, masks_r,
                                            test_size=split_size,
                                            random_state=42)

    # Carve the test set out of the remaining training samples.
    (train_x, test_x,
     train_y_l, test_y_l,
     train_y_r, test_y_r) = train_test_split(train_x, train_y_l, train_y_r,
                                             test_size=split_size,
                                             random_state=42)

    return (train_x, train_y_l, train_y_r), (val_x, val_y_l, val_y_r), (test_x, test_y_l, test_y_r)


In [None]:
def imageread(path,width=512,height=512):
    """
    Read an image from disk, resize it and scale its pixel values to [0, 1].

    Parameters
      ----------
      path : str
          Path to a single image file.

      width : int, default 512
          Width of the resized image.

      height : int, default 512
          Height of the resized image.

    Returns
      ----------
      x : np.ndarray of float32, shape (height, width, 3)
          The resized image with pixel values divided by 255.

    Raises
      ----------
      FileNotFoundError
          If OpenCV cannot read ``path``. ``cv2.imread`` returns ``None``
          instead of raising, which would otherwise surface as an opaque
          error from ``cv2.resize``.
    """
    x = cv2.imread(path, cv2.IMREAD_COLOR)  # 3-channel uint8 (OpenCV BGR order)
    if x is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    x = cv2.resize(x, (width, height))  # OpenCV's dsize is (width, height)
    return (x / 255.0).astype(np.float32)

In [None]:
def maskread(path_l, path_r,width=512,height=512):
    """
    Read the left and right lung masks, merge them into one mask, resize it
    and binarize it.

    Parameters
      ----------
      path_l : str
          Path to the left-lung mask image.

      path_r : str
          Path to the right-lung mask image.

      width : int, default 512
          Width of the resized mask.

      height : int, default 512
          Height of the resized mask.

    Returns
      ----------
      x : np.ndarray of float32, shape (height, width, 1)
          1.0 where the merged, resized mask is above half of its maximum
          value and 0.0 elsewhere. An all-zero input yields an all-zero mask.

    Raises
      ----------
      FileNotFoundError
          If OpenCV cannot read either mask file.
    """
    x_l = cv2.imread(path_l, cv2.IMREAD_GRAYSCALE)
    x_r = cv2.imread(path_r, cv2.IMREAD_GRAYSCALE)
    for mask, mask_path in ((x_l, path_l), (x_r, path_r)):
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

    # Pixel-wise union of the two lungs. Adding two uint8 arrays would wrap
    # around past 255 wherever the masks overlap; np.maximum cannot.
    x = np.maximum(x_l, x_r)
    x = cv2.resize(x, (width, height))

    # Normalize by the peak value. An empty mask is left as zeros instead of
    # dividing by zero (RuntimeWarning and NaN intermediates).
    peak = np.max(x)
    if peak > 0:
        x = x / peak
    x = (x > 0.5).astype(np.float32)
    return np.expand_dims(x, axis=-1)

In [None]:
def tf_parse(x, y_l, y_r):
    """
    Turn one (image path, left mask path, right mask path) element of a
    tf.data pipeline into an (image, mask) pair of tensors.

    The file reading itself is done in Python by ``imageread`` and
    ``maskread``, wrapped with ``tf.numpy_function`` so it can run inside
    ``Dataset.map``.

    Parameters
      ----------
      x : string tensor
          Path to the input image.

      y_l : string tensor
          Path to the left-lung segmentation mask.

      y_r : string tensor
          Path to the right-lung segmentation mask.

    Returns
      ----------
      x : float32 tensor of shape (512, 512, 3)
          Resized image with pixel values scaled to [0, 1].
      y : float32 tensor of shape (512, 512, 1)
          Resized, binarized joint lung mask.
    """
    def _load_pair(image_path, left_path, right_path):
        # numpy_function hands string tensors over as ``bytes`` objects.
        image = imageread(image_path.decode())
        mask = maskread(left_path.decode(), right_path.decode())
        return image, mask

    image, mask = tf.numpy_function(
        _load_pair, [x, y_l, y_r], [tf.float32, tf.float32]
    )
    # numpy_function drops static shape information; restore it so Keras knows
    # the input dimensions. These must match imageread/maskread's default size.
    image.set_shape([512, 512, 3])
    mask.set_shape([512, 512, 1])
    return image, mask

In [None]:
def tf_dataset(X, Y_l, Y_r, batch=8):
    """
    Build a shuffled, batched and prefetched tf.data pipeline of
    (image, mask) pairs.

    Parameters
      ----------
      X : sequence of str
          Paths to the input images.

      Y_l : sequence of str
          Paths to the left-lung segmentation masks, aligned with ``X``.

      Y_r : sequence of str
          Paths to the right-lung segmentation masks, aligned with ``X``.

      batch : int, default 8
          Number of samples per batch.

    Returns
      ----------
      dataset : tf.data.Dataset
          Yields ``(images, masks)`` batches produced by ``tf_parse``.
    """
    return (
        tf.data.Dataset.from_tensor_slices((X, Y_l, Y_r))
        .shuffle(buffer_size=200)  # shuffle with a 200-element buffer
        .map(tf_parse)             # read and preprocess each path triplet
        .batch(batch)              # group consecutive samples into batches
        .prefetch(buffer_size=4)   # prepare upcoming batches while the current one is consumed
    )

## Part 2: Setting up hyperparameters and building the model

In [None]:
""" Define Training Hyperparameters """
# Training hyperparameters
batch_size = 2                    # samples per training/validation batch
lr = 1e-5                         # initial learning rate passed to Adam
epochs = 30                       # number of passes over the training set
model_path = "/content/model.h5"  # file ModelCheckpoint writes the best model to

In [None]:
""" Load the Dataset (Training, Validation and Test sets"""
# Load the dataset as (images, left masks, right masks) path lists for the
# training, validation and test splits.
dataset_path = 'MontgomerySet'
(
    (train_x, train_y_l, train_y_r),
    (val_x, val_y_l, val_y_r),
    (test_x, test_y_l, test_y_r),
) = load_data(path=dataset_path)

In [None]:
# Build the tf.data input pipelines for training and validation.
train_dataset = tf_dataset(X=train_x, Y_l=train_y_l, Y_r=train_y_r, batch=batch_size)
val_dataset = tf_dataset(X=val_x, Y_l=val_y_l, Y_r=val_y_r, batch=batch_size)

In [None]:
""" Defining loss functions and evaluation metrics """
def iou(y_true, y_pred):
    '''
    Compute the Intersection over Union (IoU, Jaccard index) between the
    ground-truth mask and the prediction.

    The prediction is used as-is (no thresholding), so when ``y_pred`` holds
    sigmoid probabilities this is a "soft" IoU. All elements of the batch
    are pooled into a single score.

    Parameters
        ----------
        y_true : Ground-truth segmentation mask tensor.

        y_pred : Predicted segmentation mask tensor, same shape as ``y_true``.

      Returns
        ----------
        A scalar float32 tensor; in [0, 1] for inputs in [0, 1].
    '''
    # Native TensorFlow ops (as in dice_coef) instead of tf.numpy_function:
    # the metric stays inside the graph, needs no Python/host round trip and
    # returns a tensor with a known scalar shape.
    smooth = 1e-15
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

def dice_coef(y_true, y_pred):
    '''
    Compute the Dice similarity coefficient between the ground-truth mask
    and the prediction.

    All elements of the batch are pooled into one score. The prediction is
    not thresholded, so for probabilities this is a "soft" Dice.

    Parameters
        ----------
        y_true : Ground-truth segmentation mask tensor.

        y_pred : Predicted segmentation mask tensor, same shape as ``y_true``.

      Returns
        ----------
        A scalar tensor; in [0, 1] for inputs in [0, 1].
    '''
    smooth = 1e-15
    # reduce_sum over all axes already treats the tensors as flat, so there is
    # no need to instantiate new Keras Flatten layers on every call.
    # Casting y_true to the prediction dtype guards against integer masks.
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)

def dice_loss(y_true, y_pred):
    '''
    Dice loss: one minus the Dice similarity coefficient, so a perfect
    overlap gives a loss of 0.

    Parameters
        ----------
        y_true : Ground-truth segmentation mask tensor.

        y_pred : Predicted segmentation mask tensor.

      Returns
        ----------
        ``1 - dice_coef(y_true, y_pred)``.
    '''
    coefficient = dice_coef(y_true, y_pred)
    return 1.0 - coefficient

In [None]:
def conv_block(input, num_filters):
    """
    Two stacked 3x3 convolutions, each followed by batch normalization and a
    ReLU activation (Conv2D -> BatchNorm -> ReLU, twice).

    Parameters
      ----------
      input : Keras tensor
          Feature map to convolve.

      num_filters : int
          Number of output filters of each convolution.

    Returns
      ----------
      x : Keras tensor
          Output of the second ReLU. The spatial size is unchanged because
          the convolutions use ``padding="same"``.
    """
    features = input
    for _ in range(2):
        features = Conv2D(num_filters, kernel_size=3, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

def encoder_block(input, num_filters):
    """
    One contracting step of the U-Net: a conv_block followed by 2x2 max
    pooling.

    Parameters
      ----------
      input : Keras tensor
          Feature map entering this encoder level.

      num_filters : int
          Number of filters used by the conv_block.

    Returns
      ----------
      x : Keras tensor
          Full-resolution conv_block output, kept for the skip connection.
      p : Keras tensor
          ``x`` max-pooled by a factor of 2, fed to the next level.
    """
    skip = conv_block(input, num_filters)
    pooled = MaxPool2D(pool_size=(2, 2))(skip)
    return skip, pooled

def decoder_block(input, skip_features, num_filters):
    """
    One expanding step of the U-Net: upsample by 2 with a transposed
    convolution, concatenate the matching encoder features (skip connection)
    and refine the result with a conv_block.

    Parameters
      ----------
      input : Keras tensor
          Feature map coming from the level below.

      skip_features : Keras tensor
          Encoder output at the target resolution.

      num_filters : int
          Number of filters of the transposed convolution and conv_block.

    Returns
      ----------
      x : Keras tensor
          Upsampled and refined feature map.
    """
    upsampled = Conv2DTranspose(
        num_filters, kernel_size=(2, 2), strides=2, padding="same"
    )(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

In [None]:
def build_unet(input_shape):
    """
    Assemble the U-Net: four encoder levels (64 to 512 filters), a
    1024-filter bridge, four mirrored decoder levels joined by skip
    connections, and a 1x1 sigmoid convolution giving a single-channel mask.

    Parameters
      ----------
      input_shape : tuple
          Shape of one input image, e.g. ``(512, 512, 3)``. Height and width
          should be divisible by 16 (four 2x poolings) so the skip
          connections line up.

    Returns
      ----------
      model : tf.keras.Model
          The uncompiled model, named "U-Net".
    """
    encoder_filters = (64, 128, 256, 512)

    inputs = Input(shape=input_shape)

    # Encoder: keep each level's pre-pooling features for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bridge at the lowest resolution.
    x = conv_block(x, 1024)

    # Decoder: walk back up, pairing each level with its matching skip.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(x)
    return Model(inputs, outputs, name="U-Net")

In [None]:
# Build the model for 512x512 three-channel inputs.
model = build_unet((512, 512, 3))

# Metrics reported during training.
metrics = [dice_coef, iou, Recall(), Precision()]

# In this implementation the model is trained with the binary cross-entropy loss.
model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=lr),
    metrics=metrics,
)

In [None]:
# Render the network graph to U-Net.png (needs pydot and graphviz installed).
from tensorflow import keras
keras.utils.plot_model(model, to_file="U-Net.png", show_shapes=True)

You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.


In [None]:
# Print a layer-by-layer overview of the architecture with output shapes and parameter counts.
model.summary()

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 512, 512, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 512, 512, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                             

                                                                                                  
 conv2d_8 (Conv2D)              (None, 32, 32, 1024  4719616     ['max_pooling2d_3[0][0]']        
                                )                                                                 
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 32, 32, 1024  4096       ['conv2d_8[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 activation_8 (Activation)      (None, 32, 32, 1024  0           ['batch_normalization_8[0][0]']  
                                )                                                                 
                                                                                                  
 conv2d_9 

                                                                                                  
 activation_15 (Activation)     (None, 256, 256, 12  0           ['batch_normalization_15[0][0]'] 
                                8)                                                                
                                                                                                  
 conv2d_transpose_3 (Conv2DTran  (None, 512, 512, 64  32832      ['activation_15[0][0]']          
 spose)                         )                                                                 
                                                                                                  
 concatenate_3 (Concatenate)    (None, 512, 512, 12  0           ['conv2d_transpose_3[0][0]',     
                                8)                                'activation_1[0][0]']           
                                                                                                  
 conv2d_16

## Part 3: Training the model and evaluating it on the test set

In [None]:
# Callbacks used during training:
#  - ModelCheckpoint writes the model to model_path whenever val_loss improves.
#  - ReduceLROnPlateau multiplies the learning rate by 0.1 after 5 epochs
#    without val_loss improvement, never going below 1e-8.
callbacks = [
    ModelCheckpoint(model_path, monitor='val_loss', save_best_only=True, verbose=1),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-8, verbose=1),
]

In [None]:
# Train the network; the callbacks checkpoint the best model and decay the learning rate.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=callbacks,
)

Epoch 1/30
Epoch 1: val_loss improved from inf to 0.75037, saving model to /content\model.h5
Epoch 2/30
Epoch 2: val_loss did not improve from 0.75037
Epoch 3/30
Epoch 3: val_loss improved from 0.75037 to 0.73594, saving model to /content\model.h5
Epoch 4/30
Epoch 4: val_loss improved from 0.73594 to 0.67099, saving model to /content\model.h5
Epoch 5/30
Epoch 5: val_loss improved from 0.67099 to 0.61033, saving model to /content\model.h5
Epoch 6/30
Epoch 6: val_loss improved from 0.61033 to 0.58439, saving model to /content\model.h5
Epoch 7/30
Epoch 7: val_loss did not improve from 0.58439
Epoch 8/30
Epoch 8: val_loss improved from 0.58439 to 0.57333, saving model to /content\model.h5
Epoch 9/30
Epoch 9: val_loss improved from 0.57333 to 0.46258, saving model to /content\model.h5
Epoch 10/30
Epoch 10: val_loss improved from 0.46258 to 0.34142, saving model to /content\model.h5
Epoch 11/30
Epoch 11: val_loss improved from 0.34142 to 0.15539, saving model to /content\model.h5
Epoch 12/30

Epoch 18/30
Epoch 18: val_loss improved from 0.11227 to 0.10920, saving model to /content\model.h5
Epoch 19/30
Epoch 19: val_loss improved from 0.10920 to 0.10827, saving model to /content\model.h5
Epoch 20/30
Epoch 20: val_loss improved from 0.10827 to 0.10686, saving model to /content\model.h5
Epoch 21/30
Epoch 21: val_loss did not improve from 0.10686
Epoch 22/30
Epoch 22: val_loss improved from 0.10686 to 0.10377, saving model to /content\model.h5
Epoch 23/30
Epoch 23: val_loss did not improve from 0.10377
Epoch 24/30
Epoch 24: val_loss did not improve from 0.10377
Epoch 25/30
Epoch 25: val_loss did not improve from 0.10377
Epoch 26/30
Epoch 26: val_loss improved from 0.10377 to 0.10025, saving model to /content\model.h5
Epoch 27/30
Epoch 27: val_loss improved from 0.10025 to 0.09861, saving model to /content\model.h5
Epoch 28/30
Epoch 28: val_loss did not improve from 0.09861
Epoch 29/30
Epoch 29: val_loss improved from 0.09861 to 0.09717, saving model to /content\model.h5
Epoch 3

In [None]:
from tensorflow.keras.utils import CustomObjectScope

# The model was compiled with custom metric functions, so Keras needs them in
# scope (by name) to deserialize the saved model.
with CustomObjectScope({'iou': iou, 'dice_coef': dice_coef, 'dice_loss': dice_loss}):
    # Reload from the same path the ModelCheckpoint callback wrote to (it keeps
    # the lowest-val_loss model), rather than a separately hard-coded path.
    model = tf.keras.models.load_model(model_path)

In [None]:
""" Predicting the mask """
from tqdm import tqdm
import matplotlib.pyplot as plt

# For every test sample, write "<index>.png" showing side by side:
#   input image | ground-truth mask | predicted mask
# separated by 10-pixel-wide white bars.
IMG_SIZE = 512  # must match the input size the model was built with

for ct, (x, y_l, y_r) in enumerate(tqdm(zip(test_x, test_y_l, test_y_r), total=len(test_x))):
    # Reading the image: keep the uint8 copy for display and feed a
    # normalized float32 batch of one to the model.
    ori_x = cv2.imread(x, cv2.IMREAD_COLOR)
    ori_x = cv2.resize(ori_x, (IMG_SIZE, IMG_SIZE))
    batch_x = np.expand_dims((ori_x / 255.0).astype(np.float32), axis=0)  # (1, 512, 512, 3)

    # Reading the mask: pixel-wise union of both lungs. np.maximum avoids the
    # uint8 wrap-around that adding the two arrays causes where they overlap.
    ori_y = np.maximum(cv2.imread(y_l, cv2.IMREAD_GRAYSCALE),
                       cv2.imread(y_r, cv2.IMREAD_GRAYSCALE))
    ori_y = cv2.resize(ori_y, (IMG_SIZE, IMG_SIZE))
    ori_y = np.repeat(ori_y[..., np.newaxis], 3, axis=-1)  # (512, 512, 3)

    # Predicting the mask: threshold the sigmoid output at 0.5 and map it to 0/255.
    # verbose=0 keeps Keras' per-call progress bar from interleaving with tqdm.
    y_pred = (model.predict(batch_x, verbose=0)[0] > 0.5).astype(np.uint8) * 255  # (512, 512, 1)
    y_pred = np.repeat(y_pred, 3, axis=-1)  # (512, 512, 3)

    # Saving the image, ground truth and prediction as a single uint8 picture.
    sep_line = np.full((IMG_SIZE, 10, 3), 255, dtype=np.uint8)
    cat_image = np.concatenate([ori_x, sep_line, ori_y, sep_line, y_pred], axis=1)
    cv2.imwrite(f"{ct}.png", cat_image)

  0%|                                                                                           | 0/13 [00:00<?, ?it/s]



  8%|██████▍                                                                            | 1/13 [00:00<00:06,  1.77it/s]



 15%|████████████▊                                                                      | 2/13 [00:00<00:04,  2.63it/s]



 23%|███████████████████▏                                                               | 3/13 [00:01<00:03,  3.00it/s]



 31%|█████████████████████████▌                                                         | 4/13 [00:01<00:02,  3.19it/s]



 38%|███████████████████████████████▉                                                   | 5/13 [00:01<00:02,  3.35it/s]



 46%|██████████████████████████████████████▎                                            | 6/13 [00:01<00:02,  3.43it/s]



 54%|████████████████████████████████████████████▋                                      | 7/13 [00:02<00:01,  3.60it/s]



 62%|███████████████████████████████████████████████████                                | 8/13 [00:02<00:01,  3.48it/s]



 69%|█████████████████████████████████████████████████████████▍                         | 9/13 [00:02<00:01,  3.42it/s]



 77%|███████████████████████████████████████████████████████████████                   | 10/13 [00:03<00:00,  3.47it/s]



 85%|█████████████████████████████████████████████████████████████████████▍            | 11/13 [00:03<00:00,  3.43it/s]



 92%|███████████████████████████████████████████████████████████████████████████▋      | 12/13 [00:03<00:00,  3.43it/s]



100%|██████████████████████████████████████████████████████████████████████████████████| 13/13 [00:03<00:00,  3.27it/s]
