### Semantic Segmentation of Brain Abnormalities

In this notebook, I want to challenge myself by implementing a U-Net model that segments brain abnormalities in MR images pixel by pixel. The dataset I will be using contains brain MR images together with manual FLAIR abnormality segmentation masks obtained from The Cancer Imaging Archive (TCIA).
Images correspond to 110 patients included in The Cancer Genome Atlas (TCGA) lower-grade glioma collection with at least fluid-attenuated inversion recovery (FLAIR) sequence and genomic cluster data available. (Data source: [The Cancer Imaging Archive](https://wiki.cancerimagingarchive.net/display/Public/TCGA-LGG#6abaca285cee4c9cac59b0bcff944658))


#### Dependencies (Required Frameworks):

In [None]:
import os
import numpy as np
import random 
import datetime
import matplotlib.pyplot as plt
from tqdm import tqdm

# TensorFlow
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import TensorBoard

tf.keras.backend.backend()

# Imaging frameworks
from PIL import Image
from skimage.transform import rotate, rescale
import skimage.io
import cv2

#### Cleaning and Loading the data:

For each patient MR Images and segmentation Masks have been stored in the same directory. Masks have the same name as images plus a "_mask" string.

First, we remove the unused files and save Images and Masks in 2 different directories.

In [None]:
# The directory of the image dataset.
# os.listdir does not expand "~", so resolve the home directory explicitly.
Dataset_path = os.path.expanduser("~/tmp/TCGA-LGG")

# Keep only the per-patient folders; loose files that sit next to them
# (data.csv, README.md, ...) are skipped without having to name each one.
folders = sorted(
    entry for entry in os.listdir(Dataset_path)
    if os.path.isdir(os.path.join(Dataset_path, entry))
)

# Masks and Images temporary paths
temp_image_path, temp_mask_path = [], []

# Walk every patient folder and sort its files into images and masks.
# The file loop is nested inside the folder loop so that every patient is
# scanned, and each path is joined to the patient folder the file lives in.
for folder in tqdm(folders):
    img_mask_path = os.path.join(Dataset_path, folder)
    for file in sorted(os.listdir(img_mask_path)):
        if not file.endswith(".tif"):
            continue  # ignore anything that is not an image or mask slice
        if file.endswith("_mask.tif"):
            temp_mask_path.append(os.path.join(img_mask_path, file))
        else:
            temp_image_path.append(os.path.join(img_mask_path, file))

temp_image_path and temp_mask_path include MR Images and segmentation Masks for all patients.

If there are images without masks, the code below reports them and leaves them out of the paired image/mask lists.

In [None]:
img_path, mask_path = [], []  # Paired image and mask paths (same index = same slice)
img_wo_mask = []  # Paths of images that have no mask

# Index every mask by its path without the extension, so each image is matched
# with one dictionary lookup instead of a scan over all masks.
# os.path.splitext only strips the final extension; splitting on the first "."
# would cut the path short as soon as any directory name contains a dot.
mask_by_stem = {os.path.splitext(mask)[0]: mask for mask in temp_mask_path}

for img in tqdm(temp_image_path):
    mask = mask_by_stem.get(os.path.splitext(img)[0] + "_mask")
    if mask is None:
        img_wo_mask.append(img)
    else:
        img_path.append(img)
        mask_path.append(mask)

# Report after the loop so the messages do not interleave with the progress bar
if not img_wo_mask:
    print('\033[1m' + "All images have mask!")
else:
    for img in img_wo_mask:
        print('\033[1m' + img + " does not have mask!")

Splitting the data into train, validation and test sets!

In [None]:
def datasetdir(directory):
    """Create ``directory`` if it does not exist and return its usable path.

    A leading "~" is expanded to the user's home directory, and missing parent
    directories are created as well.

    Args:
        directory: Path of the directory to create.

    Returns:
        The expanded directory path, so callers can rebind their variable to
        the result without losing the path.
    """
    directory = os.path.expanduser(directory)
    try:
        os.makedirs(directory)
        print('\033[1m' + "Directory " , directory ,  " Created ")
    except FileExistsError:
        print('\033[1m' + "Directory " , directory ,  " already exists")
    return directory

# Output directories for the three splits. "~" is expanded here because the
# os functions used later treat it as a literal directory name.
train_dir = os.path.expanduser("~/tmp/train")
val_dir = os.path.expanduser("~/tmp/validate")
test_dir = os.path.expanduser("~/tmp/test")

# Create the directories. The path variables are deliberately not rebound to
# the helper's result, so they always keep holding the path strings.
for split_dir in (train_dir, val_dir, test_dir):
    datasetdir(split_dir)

Split the data inside img_path and mask_path into the three directories train_dir, val_dir and test_dir.

In [None]:
random.seed(2020)

# Shuffle images and masks as pairs. Shuffling the two lists separately gives
# each list a different permutation and breaks the image/mask correspondence.
pairs = list(zip(img_path, mask_path))
random.shuffle(pairs)
img_path = [img for img, _ in pairs]
mask_path = [mask for _, mask in pairs]

train_split = int(0.7 * len(img_path))  # train: first 70%
val_split = int(0.9 * len(img_path))  # validation: next 20%, test: last 10%

def dataset(img_path, img_dir, split1, split2):
    """Save the slice ``img_path[split1:split2]`` into ``img_dir``.

    Each file is opened with PIL and saved under its original file name inside
    ``img_dir``.

    Args:
        img_path: List of source file paths.
        img_dir: Destination directory.
        split1: Start index of the slice (``None`` for the beginning).
        split2: End index of the slice (``None`` for the end).

    Returns:
        The list of source paths that were saved.
    """
    images = img_path[split1:split2]
    for src in tqdm(images):
        # os.path handles the platform's separator; a hard-coded backslash
        # only works on Windows.
        target = os.path.join(img_dir, os.path.basename(src))
        with Image.open(src) as picture:  # close the file handle once saved
            picture.save(target)
    return images

# Images and their masks go into the same split directory. The mask lists are
# sliced from mask_path (not img_path) so the masks are actually copied.
train_images = dataset(img_path, train_dir, None, train_split)
train_masks = dataset(mask_path, train_dir, None, train_split)
val_images = dataset(img_path, val_dir, train_split, val_split)
val_masks = dataset(mask_path, val_dir, train_split, val_split)
test_images = dataset(img_path, test_dir, val_split, None)
test_masks = dataset(mask_path, test_dir, val_split, None)

#### Image Augmentation:

The performance of deep learning neural networks often improves with the amount of data available.

Data augmentation is a technique to artificially create new training data from existing training data. This is done by applying domain-specific techniques to examples from the training data that create new and different training examples.

Image data augmentation is perhaps the most well-known type of data augmentation and involves creating transformed versions of images in the training dataset that belong to the same class as the original image.

Transforms include a range of operations from the field of image manipulation, such as shifts, flips, zooms, and much more.

In this section, we will use two techniques: rotating and scaling the images.

In [None]:
def rotate_img(img, img_mask):
    """Rotate an image and its mask by the same random angle.

    The angle magnitude is drawn uniformly from [5.0, 15.0) and its sign is
    chosen at random. The image is interpolated with order 3 and the mask with
    order 0; neither is resized.
    """
    magnitude = np.random.uniform(5.0, 15.0)
    direction = np.random.choice([-1.0, 1.0], 1)[0]
    angle = magnitude * direction

    rotated_img = rotate(img, angle, resize=False, order=3, preserve_range=True)
    rotated_mask = rotate(img_mask, angle, resize=False, order=0, preserve_range=True)
    return rotated_img, rotated_mask

def scale_img(img, img_mask, img_height = 256, img_width = 256):
    """Rescale an image and its mask by the same random factor.

    The factor is ``1 ± u`` with ``u`` drawn uniformly from [0.04, 0.08) and a
    random sign. An enlarged result is centre-cropped and a shrunken result is
    zero-padded. The image is interpolated with order 3, the mask with order 0.

    Args:
        img: Image array.
        img_mask: Mask array belonging to ``img``.
        img_height: Target height used for cropping and padding.
        img_width: Target width used for cropping.

    Returns:
        Tuple ``(img, img_mask)`` after scaling and cropping/padding.
    """
    scale = 1.0 + np.random.uniform(0.04, 0.08) * np.random.choice([-1.0, 1.0], 1)[0]

    img = rescale(img, scale, order=3, preserve_range=True)
    img_mask = rescale(img_mask, scale, order=0, preserve_range=True)
    if scale > 1:
        # center_crop takes (cropx, cropy), i.e. width first, then height
        img = center_crop(img, img_width, img_height)
        img_mask = center_crop(img_mask, img_width, img_height)
    else:
        # NOTE(review): padding is driven by img_height only, so the padded
        # branch assumes square targets - confirm before using non-square sizes
        img = zeros_pad(img, img_height)
        img_mask = zeros_pad(img_mask, img_height)

    return img, img_mask

def center_crop(img, cropx, cropy):
    """Return the central window of ``img``, ``cropy`` rows by ``cropx`` columns."""
    rows, cols = img.shape[0], img.shape[1]
    top = rows // 2 - (cropy // 2)
    left = cols // 2 - (cropx // 2)
    bottom = top + cropy
    right = left + cropx
    return img[top:bottom, left:right]

def zeros_pad(img, size):
    """Zero-pad the leading (row, column) axes of ``img`` up to ``size``.

    Padding is split between both sides of an axis; when the amount is odd the
    split follows Python's ``round``. Each axis is padded according to its own
    length, so non-square inputs end up at the requested size as well. Axes
    after the first two (e.g. a channel axis) are left untouched.

    Args:
        img: Array to pad.
        size: Target edge length, either one number used for both axes or a
            ``(height, width)`` pair.

    Returns:
        The padded array.

    Raises:
        ValueError: Raised by ``np.pad`` when ``img`` is larger than ``size``
            along a padded axis.
    """
    targets = (size, size) if np.isscalar(size) else tuple(size)

    pad_width = [(0, 0)] * img.ndim
    for axis, target in enumerate(targets[:img.ndim]):
        missing = target - img.shape[axis]
        before = int(round(missing / 2.0))
        pad_width[axis] = (before, missing - before)
    return np.pad(img, pad_width, mode="constant")

In [None]:
def aug(images, masks):
    """Append augmented copies of every sample whose mask is not empty.

    For each sample whose mask reaches a value of at least 1, four extra
    samples are produced: one rotated, one rescaled, and two unmodified
    duplicates. Samples with an empty mask are not augmented.

    Args:
        images: Array of images, first axis indexing samples.
        masks: Array of masks aligned with ``images``.

    Returns:
        Tuple ``(images, masks)`` with the extra samples stacked after the
        originals. When no sample qualifies, the inputs are returned unchanged.
    """
    images_augmentation = []
    masks_augmentation = []

    for image, mask in tqdm(zip(images, masks), total=len(images)):
        # Skip samples whose mask contains no foreground
        if np.max(mask) < 1:
            continue

        # rotating
        image_rotate, mask_rotate = rotate_img(image, mask)
        images_augmentation.append(image_rotate)
        masks_augmentation.append(mask_rotate)

        # scaling
        image_scale, mask_scale = scale_img(image, mask)
        images_augmentation.append(image_scale)
        masks_augmentation.append(mask_scale)

        # Duplicate the sample twice to give foreground slices more weight
        for _ in range(2):
            images_augmentation.append(image)
            masks_augmentation.append(mask)

    # An empty list becomes a 1-D array that np.vstack cannot stack onto the
    # originals, so return early when nothing was augmented.
    if not images_augmentation:
        return images, masks

    images_augmentation = np.array(images_augmentation)
    masks_augmentation = np.array(masks_augmentation)

    # add images and augmented images together
    return np.vstack((images, images_augmentation)), np.vstack((masks, masks_augmentation))

#### Model's input data:

As the final step, we read the data as matrices of numbers ready to feed into the model!

In [None]:
def data(path, img_height = 256, img_width = 256, channels = 3, augmentation = True):
    """Load every image/mask pair in ``path`` into arrays ready for the model.

    Every file whose name does not contain "mask" is treated as an image; its
    mask is expected next to it as ``<image stem>_mask.tif``. Images are read
    as grey-scale and standardised with the mean and standard deviation of the
    whole set. Masks are divided by 255.

    Args:
        path: Directory holding the images and their masks.
        img_height: Height of the returned arrays.
        img_width: Width of the returned arrays.
        channels: Not used; kept so existing calls keep working. The returned
            arrays always have one channel.
        augmentation: When true, the result is passed through ``aug``.

    Returns:
        Tuple ``(images, masks)`` of float32 arrays shaped
        ``(N, img_height, img_width, 1)`` (more samples after augmentation).
    """
    # Count the images themselves instead of assuming that exactly half of the
    # directory entries are images. Sorted for a reproducible sample order.
    image_names = sorted(name for name in os.listdir(path) if "mask" not in name)
    total_count = len(image_names)

    # Images are kept as floats: reading a colour file with as_gray=True yields
    # floating-point values, which an integer buffer would truncate.
    images = np.zeros((total_count, img_height, img_width), dtype=np.float32)
    masks = np.zeros((total_count, img_height, img_width), dtype=np.uint8)

    for i, image_name in enumerate(tqdm(image_names)):
        img = skimage.io.imread(os.path.join(path, image_name), as_gray=True)  # read the image
        if img.shape != (img_height, img_width):
            img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_LINEAR)
        images[i] = img

        image_mask_name = os.path.splitext(image_name)[0] + "_mask.tif"
        img_mask = skimage.io.imread(os.path.join(path, image_mask_name), as_gray=True)  # read the mask
        # Resize to the size of the output buffer (cv2 expects width, height);
        # nearest-neighbour keeps the mask values unblended.
        img_mask = cv2.resize(img_mask, (img_width, img_height), interpolation=cv2.INTER_NEAREST)
        masks[i] = img_mask

    images = images[..., np.newaxis]
    if total_count:
        mean = np.mean(images)
        std = np.std(images)
        images -= mean
        if std > 0:  # a constant image set would otherwise divide by zero
            images /= std

    masks = masks[..., np.newaxis]
    masks = masks.astype("float32")
    masks /= 255.

    if augmentation:
        images, masks = aug(images, masks)

    return images, masks

#### Unet Model:

The [**UNET**](https://arxiv.org/abs/1505.04597) was developed by Olaf Ronneberger et al. for biomedical image segmentation. The architecture contains two paths. The first is the contracting path (also called the encoder), which captures the context in the image; it is a traditional stack of convolutional and max-pooling layers. The second is the symmetric expanding path (also called the decoder), which enables precise localization using transposed convolutions. It is therefore an end-to-end fully convolutional network (FCN): it contains only convolutional layers and no dense layers, which is why it can accept images of any size.

In the original paper, the **UNET** is described as follows:

![U-net](https://miro.medium.com/max/680/1*TXfEPqTbFBPCbXYh2bstlA.png)


In [None]:
def network(img_height = 256, img_width = 256, channels = 3):

    """UNET structure implemented using TensorFlow

    Four encoder stages (32, 64, 128 and 256 filters), each made of two 3x3
    'same' convolutions with relu followed by 2x2 max pooling, lead into a
    512-filter bottleneck. Four decoder stages mirror them: a 2x2 transposed
    convolution with stride 2, concatenation with the matching encoder output
    along axis 3, then two more 3x3 convolutions. A final 1x1 convolution with
    sigmoid activation produces a single output channel.
    """

    def double_conv(tensor, filters):
        # Two rounds of: 3 by 3 convolution with same padding + relu activation
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), padding='same')(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    inputs = Input((img_height, img_width, channels))  # Input image information

    # Contracting path: remember each stage's output for the skip connections
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)  # Maxpooling filter

    # Bottleneck
    x = double_conv(x, 512)

    # Expanding path: upsample, join the encoder output of equal depth, convolve
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip], axis=3)
        x = double_conv(x, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)  # 1 by 1 convolution with sigmoid activation

    return Model(inputs=[inputs], outputs=[outputs])

#### Loss Function:

The Dice similarity coefficient, also known as the Sørensen–Dice index or simply the Dice coefficient, is a statistical tool which measures the similarity between two sets of data. This index has become arguably the most broadly used tool in the validation of image segmentation algorithms created with AI, but it is a much more general concept which can be applied to sets of data for a variety of applications, including NLP.

The equation for this concept is:
$$\mathrm{Dice}(X, Y) = \frac{2\,|X \cap Y|}{|X| + |Y|}$$

where:
- X and Y are two sets;
- a set with vertical bars on either side denotes the cardinality of the set, i.e. the number of elements in it (e.g. |X| is the number of elements in set X);
- ∩ denotes the intersection of two sets, i.e. the elements that are common to both sets.

In [None]:
def dice_coef(true, pred):
    """Return the Dice coefficient of two tensors, smoothed with 1.0.

    Both tensors are flattened; the score is
    (2 * sum(true * pred) + 1) / (sum(true) + sum(pred) + 1).
    """
    backend = tf.keras.backend
    y_true = backend.flatten(true)
    y_pred = backend.flatten(pred)

    overlap = backend.sum(y_true * y_pred)
    total = backend.sum(y_true) + backend.sum(y_pred)

    # Smooth = 1.0 in numerator and denominator
    return (2. * overlap + 1.0) / (total + 1.0)

def dice_coef_loss(true, pred):
    """Return the Dice loss, i.e. one minus the Dice coefficient."""
    score = dice_coef(true, pred)
    return 1.0 - score

#### Compiling the model:

In [None]:
# Build the network for single-channel 256x256 inputs
model = network(img_height = 256, img_width = 256, channels = 1)

# Adam optimizer with 1e-5 learning rate. `learning_rate` is the supported
# keyword; `lr` is a deprecated alias that newer Keras releases reject.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
model.compile(optimizer=optimizer, loss=dice_coef_loss, metrics=[dice_coef])

model.summary()

#### Training the model:

In [None]:
# Directory for the trained weights. "~" is expanded here because the os
# functions treat it as a literal directory name.
weights_path = os.path.expanduser("~/tmp/training_weights")
train_path = train_dir
val_path = val_dir
batch_size = 16
epochs = 30

def train(train_images, train_masks, val_images, val_masks, weights_path):
    """Train the network using the train and val data, then save its weights.

    Uses the module-level ``model``, ``batch_size`` and ``epochs``. The weights
    are written to ``<weights_path>/weights_<epochs>.h5``.

    Args:
        train_images: Training images.
        train_masks: Training masks aligned with ``train_images``.
        val_images: Validation images.
        val_masks: Validation masks aligned with ``val_images``.
        weights_path: Directory for the weights file; "~" is expanded and
            missing parent directories are created.

    Returns:
        The object returned by ``model.fit``.
    """
    # Prepare the output directory before training, so that an unusable path
    # fails immediately instead of after all epochs have run.
    weights_path = os.path.expanduser(weights_path)
    os.makedirs(weights_path, exist_ok=True)

    history = model.fit(
        train_images,
        train_masks,
        validation_data=(val_images, val_masks),
        batch_size=batch_size,
        epochs=epochs,
        shuffle=True)

    # Save training weights
    # NOTE(review): some newer Keras releases expect a ".weights.h5" suffix -
    # confirm against the installed version before changing the file name.
    model.save_weights(os.path.join(weights_path, "weights_{}.h5".format(epochs)))

    return history

In [None]:
# Load the training and validation sets with identical settings, without augmentation
load_options = dict(img_height=256, img_width=256, channels=3, augmentation=False)
train_images, train_masks = data(train_path, **load_options)
val_images, val_masks = data(val_path, **load_options)

In [None]:
# Fit the model on the loaded arrays and store the weights under weights_path
train(
    train_images=train_images,
    train_masks=train_masks,
    val_images=val_images,
    val_masks=val_masks,
    weights_path=weights_path,
)

#### Predicting and plotting unseen data:

In [None]:
# Weights file written by the training step: <weights dir>/weights_<epochs>.h5.
# The variable is named model_weight because that is the name the prediction
# call uses; "~" is expanded since file APIs treat it as a literal name.
model_weight = os.path.join(os.path.expanduser("~/tmp/training_weights"), "weights_30.h5")
mode_weight = model_weight  # previous (misspelled) name, kept as an alias

def predict(test_path, model_weight):
    """Predict Masks for the input Images

    Loads the image/mask arrays from ``test_path``, rebuilds the network,
    restores the trained weights and runs inference.

    Args:
        test_path: Directory holding the test images and their masks.
        model_weight: Path of the weights file to load; "~" is expanded.

    Returns:
        Tuple ``(test_images, test_masks, pred_masks)``.
    """
    # Read the arrays from test_path rather than relying on module-level
    # variables of the same name.
    test_images, test_masks = data(test_path, img_height = 256, img_width = 256, augmentation = False)

    # The input depth must match the loaded arrays (and the trained weights),
    # so it is taken from the data instead of being hard-coded.
    model = network(img_height = 256, img_width = 256, channels = test_images.shape[-1])
    model.load_weights(os.path.expanduser(model_weight))

    # make predictions
    pred_masks = model.predict(test_images, verbose=1)

    return test_images, test_masks, pred_masks

In [None]:
# Run inference on the held-out test directory
test_images, test_masks, pred_masks = predict(test_path=test_dir, model_weight=model_weight)

In [None]:
def myshow(image, squeeze, rgb2gray):
    """Prepare an MR Image, Ground Truth Mask or Predicted Mask for plotting.

    Args:
        image: Array to display.
        squeeze: When true, drop all axes of length one.
        rgb2gray: When true, reduce the array to a single grey plane. A
            three-channel array is converted with ``skimage.color.rgb2gray``;
            an array whose last axis has length one just loses that axis.

    Returns:
        The prepared array. With both flags false the input is returned as an
        array without modification.
    """
    img = np.asarray(image)

    if rgb2gray and img.ndim == 3:
        if img.shape[-1] == 3:
            # Local import under an alias: the parameter is also named rgb2gray
            from skimage.color import rgb2gray as to_gray
            img = to_gray(img)
        elif img.shape[-1] == 1:
            # Already grey: only the channel axis has to go
            img = img[..., 0]

    if squeeze:
        # Squeeze the converted array so an earlier grey conversion is kept
        img = np.squeeze(img)

    return img

print('\033[1m' + "\n    Original Image            Ground Truth Mask           Predicted Mask")

# One figure per test slice that has a non-empty ground-truth mask; each
# figure shows the image, its mask and the prediction side by side.
for idx, test_image in enumerate(test_images):
    if np.max(test_masks[idx]) < 1:
        continue

    #Plotting images
    fig = plt.figure(figsize=(10,10))
    fig.subplots_adjust(hspace=0.1, wspace=0.1)

    # (array, squeeze, rgb2gray) for the three panels, left to right
    panels = (
        (test_image, False, True),
        (test_masks[idx], True, False),
        (pred_masks[idx], True, False),
    )
    for column, (picture, do_squeeze, to_gray) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, column)
        ax.imshow(myshow(picture, squeeze=do_squeeze, rgb2gray=to_gray), cmap="gray")
        ax.axis(False)