<a href="https://colab.research.google.com/github/albertofernandezvillan/ml-dl-cv-notebooks/blob/main/segmentation_unet_keras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image Segmentation with Keras and UNET

TODO: include some explanation from the following links:
*   https://stackoverflow.com/questions/47435526/what-is-the-meaning-of-axis-1-in-keras-argmax
*   https://towardsdatascience.com/how-to-create-a-custom-loss-function-keras-3a89156ec69b
*   https://www.tensorflow.org/guide/tensor
*   https://colab.research.google.com/github/MarkDaoust/models/blob/segmentation_blogpost/samples/outreach/blogs/segmentation_blogpost/image_segmentation.ipynb


**Segmentation** is the process of generating a pixel-wise mask that gives the class of the object visible at each pixel.

We will be using the [Kaggle Carvana Image Masking Challenge Dataset](https://www.kaggle.com/c/carvana-image-masking-challenge). 

This dataset contains a large number of car images, with each car photographed from different angles. In addition, each car image has an associated manually cut-out mask; our task is to automatically create these masks for unseen data. 

## Specific concepts that will be covered:
* **[Functional API](https://keras.io/getting-started/functional-api-guide/)** - we will implement U-Net, a convolutional network model classically used for biomedical image segmentation, with the Functional API. 
  * This model has layers that require multiple inputs/outputs, which requires the use of the functional API
  * Check out the original [paper](https://arxiv.org/abs/1505.04597), U-Net: Convolutional Networks for Biomedical Image Segmentation, by Olaf Ronneberger et al.!
* **Custom Loss Functions and Metrics** - We'll implement a custom loss function using binary [**cross entropy**](https://developers.google.com/machine-learning/glossary/#cross-entropy) and **dice loss**. We'll also implement the **dice coefficient** (which is used for our loss), and track the dice loss as a metric to monitor our training process and judge how well we are performing. 
* **Saving and loading keras models** - We'll save our best model to disk. When we want to perform inference/evaluate our model, we'll load in the model from disk. 

In [None]:
import os
import glob
import zipfile
import functools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pandas as pd
from PIL import Image

In [None]:
import tensorflow as tf
# Use the public tf.keras API; tensorflow.python.* is a private namespace
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import models
from tensorflow.keras import backend as K

AUTOTUNE = tf.data.AUTOTUNE

# Download data from Kaggle competition

1. Install the latest version of Kaggle API
2. Upload the file kaggle.json to Colab and move it to the required directory
3. Test the API is working
4. Download required files from the competition

In [None]:
# Install kaggle API
! pip install kaggle
# Force install the latest version
! pip install --upgrade --force-reinstall --no-deps kaggle

In [None]:
# Upload kaggle.json file to Colab
from google.colab import files
uploaded = files.upload()

In [None]:
! mkdir -p ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! rm kaggle.json

In [None]:
# Check the API is working OK:
!kaggle competitions list

In [None]:
# If you want to download all the files from the competition:
# !kaggle competitions download -c carvana-image-masking-challenge

In [None]:
# In this case we are going to use only train files:
# We download required files and unzip them inside the database folder:
! mkdir carvana_dataset

! kaggle competitions download -f train.zip -c carvana-image-masking-challenge -p carvana_dataset
! kaggle competitions download -f train_masks.csv.zip -c carvana-image-masking-challenge -p carvana_dataset
! kaggle competitions download -f train_masks.zip -c carvana-image-masking-challenge -p carvana_dataset

%cd carvana_dataset

! unzip /content/carvana_dataset/train.zip
! unzip /content/carvana_dataset/train_masks.csv.zip
! unzip /content/carvana_dataset/train_masks.zip

In [None]:
%cd /content/

In [None]:
# Get the IDs of the images (and masks)
root_folder = "carvana_dataset"

img_dir = os.path.join(root_folder, "train")
label_dir = os.path.join(root_folder, "train_masks")
df_train = pd.read_csv(os.path.join(root_folder, 'train_masks.csv'))

ids_train = df_train['img'].map(lambda s: s.split('.')[0])
print(ids_train[:10])

In [None]:
x_train_filenames = []
y_train_filenames = []

for img_id in ids_train:
  x_train_filenames.append(os.path.join(img_dir, "{}.jpg".format(img_id)))
  y_train_filenames.append(os.path.join(label_dir, "{}_mask.gif".format(img_id)))

In [None]:
x_train_filenames, x_val_filenames, y_train_filenames, y_val_filenames = \
                    train_test_split(x_train_filenames, y_train_filenames, test_size=0.15, random_state=42)

In [None]:
num_train_examples = len(x_train_filenames)
num_val_examples = len(x_val_filenames)

print("Number of training examples: {}".format(num_train_examples))
print("Number of validation examples: {}".format(num_val_examples))

print("Some sample paths:")
print(x_train_filenames[:10])
print(y_train_filenames[:10])

# Exploring the dataset

Let's take a look at some of the examples of different images in our dataset. 

In [None]:
num_display = 4

# Generate a uniform random sample from np.arange(num_train_examples) of size num_display:
r_choices = np.random.choice(num_train_examples, num_display)

plt.figure(figsize=(10, 15))
for i in range(0, num_display):
  index = r_choices[i]
  x_pathname = x_train_filenames[index]
  y_pathname = y_train_filenames[index]

  img = Image.open(x_pathname)  
  plt.subplot(num_display, 2, i * 2 + 1)
  plt.imshow(img)
  plt.title("Original Image")
  
  mask = Image.open(y_pathname)  
  plt.subplot(num_display, 2, i * 2 + 2)
  plt.imshow(mask)
  plt.title("Masked Image")  
  
plt.suptitle("Examples of Images and their Masks")
plt.show()

# Set up 

Let’s begin by setting up some parameters. We’ll resize all images to a common shape and set up some training parameters, with some notes:

* Due to the architecture of our UNet version, the height and width of the image must be divisible by `32`, as each of the five `MaxPooling2D` layers downsamples the spatial resolution by a factor of `2` (2^5 = 32).
* If your machine can support it, you will achieve better performance using a higher-resolution input image (e.g. 512 by 512), as this allows more precise localization and loses less information during encoding. In addition, you can also make the model deeper.
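
The divisibility constraint can be checked before building the model. The following is a minimal sketch: `check_input_size` and `num_pool_layers` are hypothetical names that are not part of this notebook, and the default of five pooling stages is an assumption matching the encoder defined later.

```python
def check_input_size(height, width, num_pool_layers=5):
    """Return the spatial size at the bottleneck, or raise if the input
    cannot be halved cleanly `num_pool_layers` times."""
    factor = 2 ** num_pool_layers  # 2**5 == 32 for five pooling stages
    if height % factor or width % factor:
        raise ValueError(
            "Input size {}x{} must be divisible by {}".format(height, width, factor))
    return height // factor, width // factor


print(check_input_size(256, 256))  # (8, 8)
print(check_input_size(512, 512))  # (16, 16)
```

A size such as 250 by 256 raises a `ValueError`, since 250 cannot be halved five times without a remainder.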

In [None]:
img_shape = (256, 256, 3)
batch_size = 3
epochs = 5

# Build our input pipeline with `tf.data`
Since we begin with filenames, we will need to build a robust and scalable data pipeline that will play nicely with our model. 

Data augmentation "increases" the amount of training data by applying a number of random transformations to it. During training, our model never sees the exact same picture twice. This helps prevent overfitting and helps the model generalize better to unseen data.

Our input pipeline will consist of the following steps:
1. Read the bytes of the file in from the filename - for both the image and the label. Recall that our labels are actually images with each pixel annotated as car or background (1, 0). 
2. Decode the bytes into an image format
3. Apply image transformations: (optional, according to input parameters)
  * `resize` - Resize our images to a standard size (as determined by EDA or computation/memory restrictions)
    * The reason why this is optional is that U-Net is a fully convolutional network (i.e. with no fully connected units) and is thus not dependent on the input size. However, if you choose not to resize the images, you must use a batch size of 1, since you cannot batch images of different sizes together
    * Alternatively, you could also bucket your images together and resize them per mini-batch to avoid resizing images as much, as resizing may affect your performance through interpolation, etc.
  * `hue_delta` - Adjusts the hue of an RGB image by a random factor. This is only applied to the actual image (not our label image). The `hue_delta` must be in the interval `[0, 0.5]` 
  * `horizontal_flip` - flip the image horizontally along the central axis with a 0.5 probability. This transformation must be applied to both the label and the actual image. 
  * `width_shift_range` and `height_shift_range` are ranges (as a fraction of total width or height) within which to randomly translate the image either horizontally or vertically. This transformation must be applied to both the label and the actual image. 
  * `scale` - rescale the image by a certain factor, e.g. 1 / 255.
4. Shuffle the data, batch the data, then prefetch a batch (for efficiency). No `repeat` is used below; the dataset is finite and is iterated once per epoch.

It is important to note that the transformations in your data pipeline must be symbolic transformations, i.e. expressed with TensorFlow ops so that they can run inside the `tf.data` graph. 
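
To make the "applied to both the label and the actual image" requirement concrete, here is a small NumPy sketch in which a single random draw decides the flip for both arrays. It is an eager illustration only, not the symbolic `tf.data` pipeline built below, and `paired_random_flip` is a hypothetical helper introduced for this example.

```python
import numpy as np


def paired_random_flip(img, mask, rng, prob=0.5):
    """Flip image and mask left-right together, using a single random draw."""
    if rng.random() < prob:
        # Axis 1 is the width axis for (height, width, channels) arrays
        return img[:, ::-1, :], mask[:, ::-1, :]
    return img, mask


rng = np.random.default_rng(0)
img = np.arange(2 * 4 * 3).reshape(2, 4, 3)       # toy "image"
mask = (img[:, :, :1] % 2 == 0).astype(np.uint8)  # toy one-channel "mask"

out_img, out_mask = paired_random_flip(img, mask, rng, prob=1.0)  # force a flip
# Each mask pixel still lines up with the image pixel it labels
assert np.array_equal(out_mask, (out_img[:, :, :1] % 2 == 0).astype(np.uint8))
```

If the image and the mask drew their random numbers separately, about half of the training pairs would end up misaligned.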

## Processing each pathname

In [None]:
def _process_pathnames(fname, label_path):
  # 1. Process the image 
  img_str = tf.io.read_file(fname)
  img = tf.image.decode_jpeg(img_str, channels=3)  # tf.Tensor([1280 1918 3],shape=(3,),dtype=int32)
  
  # 2. For processing the label: 
  label_img_str = tf.io.read_file(label_path) 
  # 2.a) load and get the first image of the GIF
  # These are gif images so they return as (num_frames, h, w, c)
  # In this case, for example: tf.Tensor([1 1280 1918 3],shape=(4,),dtype=int32)
  label_img = tf.image.decode_gif(label_img_str)[0] # tf.Tensor([1280 1918 3],shape=(3,),dtype=int32)

  # 2.b) We take the first channel only. 
  label_img = label_img[:, :, 0] # tf.Tensor([1280 1918],shape=(2,),dtype=int32)
  label_img = tf.expand_dims(label_img, axis=-1) # tf.Tensor([1280 1918 1],shape=(3,),dtype=int32)

  # These last two steps can also be performed as follows:
  # label_img = label_img[:, :, 0:1] # tf.Tensor([1280 1918 1],shape=(3,),dtype=int32)

  return img, label_img

In [None]:
# Check that the coded function above works OK
# Also check the dimension of each tensor
img, label_img = _process_pathnames(x_train_filenames[0], y_train_filenames[0])
print(tf.shape(img)) # tf.Tensor([1280 1918 3], shape=(3,), dtype=int32)
print(tf.shape(label_img)) # tf.Tensor([1280 1918 1], shape=(3,), dtype=int32)

In [None]:
# Show both the image and the mask to confirm that the previous function is 
# working as expected
plt.figure(figsize=(10, 4))

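# Use fixed subplot indices; the loop variable `i` from an earlier cell
# would produce an out-of-range index here
plt.subplot(1, 2, 1)
plt.imshow(img.numpy())
plt.title("Original Image")
  
plt.subplot(1, 2, 2)
# Drop the trailing channel axis so imshow receives a 2-D array
plt.imshow(label_img.numpy()[:, :, 0])
plt.title("Masked Image")  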
  
plt.suptitle("Examples of Images and their Masks")
plt.show()

## Shifting the image

In [None]:
!pip install tensorflow-addons

In [None]:
import tensorflow_addons as tfa

In [None]:
def shift_img(output_img, label_img, width_shift_range, height_shift_range):
  """This fn will perform the horizontal or vertical shift"""
  if width_shift_range or height_shift_range:
      if width_shift_range:
        width_shift_range = tf.random.uniform([], 
                                              -width_shift_range * img_shape[1],
                                              width_shift_range * img_shape[1])
      if height_shift_range:
        height_shift_range = tf.random.uniform([],
                                               -height_shift_range * img_shape[0],
                                               height_shift_range * img_shape[0])
      
      output_img = tfa.image.translate(output_img,
                                       [width_shift_range, height_shift_range])
      label_img = tfa.image.translate(label_img,
                                      [width_shift_range, height_shift_range])
  return output_img, label_img

## Flipping the image randomly 

In [None]:
def flip_img(horizontal_flip, tr_img, label_img):
  if horizontal_flip:
    flip_prob = tf.random.uniform([], 0.0, 1.0)
    tr_img, label_img = tf.cond(tf.less(flip_prob, 0.5),
                                lambda: (tf.image.flip_left_right(tr_img), tf.image.flip_left_right(label_img)),
                                lambda: (tr_img, label_img))
  return tr_img, label_img

## Assembling our transformations into our augment function

In [None]:
@tf.function
def _augment(img,
             label_img,
             resize=None,  # Resize the image to some size e.g. [256, 256]
             scale=1,  # Scale image e.g. 1 / 255.
             hue_delta=0,  # Adjust the hue of an RGB image by random factor
             horizontal_flip=False,  # Random left right flip,
             width_shift_range=0,  # Randomly translate the image horizontally
             height_shift_range=0):  # Randomly translate the image vertically 
  if resize is not None:
    label_img = tf.image.resize(label_img, resize)
    img = tf.image.resize(img, resize)
  
  if hue_delta:
    img = tf.image.random_hue(img, hue_delta)
  
  img, label_img = flip_img(horizontal_flip, img, label_img)
  img, label_img = shift_img(img, label_img, width_shift_range, height_shift_range)

  label_img = tf.cast(label_img, dtype=tf.float32) * scale
  img = tf.cast(img, dtype=tf.float32) * scale  

  return img, label_img

In [None]:
tr_cfg = {
    'resize': [img_shape[0], img_shape[1]],
    'scale': 1 / 255.,
    'hue_delta': 0.1,
    'horizontal_flip': True,
    'width_shift_range': 0.1,
    'height_shift_range': 0.1
}
tr_preprocessing_fn = functools.partial(_augment, **tr_cfg)

In [None]:
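train_data_set = tf.data.Dataset.from_tensor_slices((x_train_filenames, y_train_filenames))
train_data_set = (train_data_set
                  # Shuffle the (cheap) filename pairs; reshuffled on every epoch by default
                  .shuffle(len(x_train_filenames))
                  .map(_process_pathnames, num_parallel_calls = AUTOTUNE)
                  .map(tr_preprocessing_fn, num_parallel_calls = AUTOTUNE)
                  # No .cache() here: caching after the random augmentations would
                  # freeze them, so every epoch would see the same augmented images
                  .batch(batch_size)
                  .prefetch(AUTOTUNE)
                  )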

In [None]:
val_cfg = {
    'resize': [img_shape[0], img_shape[1]],
    'scale': 1 / 255.,
}
val_preprocessing_fn = functools.partial(_augment, **val_cfg)

In [None]:
val_data_set = tf.data.Dataset.from_tensor_slices((x_val_filenames, y_val_filenames))
val_data_set = (val_data_set
                .map(_process_pathnames, num_parallel_calls = AUTOTUNE)
                .map(val_preprocessing_fn, num_parallel_calls = AUTOTUNE)
                .cache()
                .batch(batch_size)
                .prefetch(AUTOTUNE)
                )

In [None]:
for next_element in train_data_set:
    batch_of_imgs, label = next_element
    plt.figure(figsize=(10, 10))
    img = batch_of_imgs[0]

    plt.subplot(1, 2, 1)
    plt.imshow(img)

    plt.subplot(1, 2, 2)
    plt.imshow(label[0, :, :, 0])
    plt.show()
    break

In [None]:
for next_element in val_data_set:
    batch_of_imgs, label = next_element
    plt.figure(figsize=(10, 10))
    img = batch_of_imgs[0]

    plt.subplot(1, 2, 1)
    plt.imshow(img)

    plt.subplot(1, 2, 2)
    plt.imshow(label[0, :, :, 0])
    plt.show()
    break

# Build the model
We'll build the U-Net model. U-Net is especially good at segmentation tasks because it localizes well and can therefore produce high-resolution segmentation masks. It also works well with small datasets and is relatively robust against overfitting: the effective training data is the number of patches within the images, which is much larger than the number of training images. Unlike the original model, we add batch normalization to each of our blocks. 

U-Net is built from an encoder portion and a decoder portion. The encoder portion is composed of a linear stack of [`Conv`](https://developers.google.com/machine-learning/glossary/#convolution), `BatchNorm`, and [`Relu`](https://developers.google.com/machine-learning/glossary/#ReLU) operations followed by a [`MaxPool`](https://developers.google.com/machine-learning/glossary/#pooling). Each `MaxPool` reduces the spatial resolution of our feature map by a factor of 2. We keep track of the output of each block, as we feed these high-resolution feature maps to the decoder portion. The decoder portion is composed of `Conv2DTranspose` (which upsamples by a factor of 2), `Conv`, `BatchNorm`, and `Relu` operations. Note that on the decoder side we concatenate the encoder feature map of the same size with the upsampled one. Finally, we add a final `Conv` operation that performs a convolution along the channels for each individual pixel (kernel size of (1, 1)) and outputs our final segmentation mask as a single grayscale channel. 
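
The encoder/decoder bookkeeping described above can be traced without building any layers. This is a minimal pure-Python sketch, assuming "same" padding, a 256-pixel input, and the filter counts used in the model cell further down; `unet_shapes` is a hypothetical helper, not part of Keras.

```python
def unet_shapes(input_size=256, filters=(32, 64, 128, 256, 512), center_filters=1024):
    """Trace (spatial size, channels) through a U-Net with 'same' padding."""
    size = input_size
    encoder = []
    for f in filters:
        encoder.append((size, f))  # conv block output, kept for the skip connection
        size //= 2                 # 2x2 max pooling halves height and width
    center = (size, center_filters)
    decoder = []
    for skip_size, f in reversed(encoder):
        size *= 2                  # upsampling doubles height and width
        assert size == skip_size   # sizes must match for concatenation
        decoder.append((size, f))
    return encoder, center, decoder


encoder, center, decoder = unet_shapes()
print(encoder)  # [(256, 32), (128, 64), (64, 128), (32, 256), (16, 512)]
print(center)   # (8, 1024)
print(decoder)  # [(16, 512), (32, 256), (64, 128), (128, 64), (256, 32)]
```

The internal assertion is the reason the input size must survive five halvings: each decoder stage can only be concatenated with an encoder output of exactly the same height and width.
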
## The Keras Functional API
The Keras functional API is used when you have multi-input/output models, shared layers, etc. It's a powerful API that allows you to manipulate tensors and build complex graphs with intertwined data streams easily. In addition, it makes **layers** and **models** both callable on tensors. 
* To see more examples, check out the [get started guide](https://keras.io/getting-started/functional-api-guide/). 

We'll build these helper functions, which let us assemble our model's block operations easily and simply. 

In [None]:
def conv_block(input_tensor, num_filters):
  encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(input_tensor)
  encoder = layers.BatchNormalization()(encoder)
  encoder = layers.Activation('relu')(encoder)
  encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(encoder)
  encoder = layers.BatchNormalization()(encoder)
  encoder = layers.Activation('relu')(encoder)
  return encoder

def encoder_block(input_tensor, num_filters):
  encoder = conv_block(input_tensor, num_filters)
  encoder_pool = layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)
  
  return encoder_pool, encoder

def decoder_block(input_tensor, concat_tensor, num_filters):
  decoder = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
  decoder = layers.concatenate([concat_tensor, decoder], axis=-1)
  decoder = layers.BatchNormalization()(decoder)
  decoder = layers.Activation('relu')(decoder)
  decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
  decoder = layers.BatchNormalization()(decoder)
  decoder = layers.Activation('relu')(decoder)
  decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
  decoder = layers.BatchNormalization()(decoder)
  decoder = layers.Activation('relu')(decoder)
  return decoder

In [None]:
inputs = layers.Input(shape=img_shape)
# 256

encoder0_pool, encoder0 = encoder_block(inputs, 32)
# 128

encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64)
# 64

encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128)
# 32

encoder3_pool, encoder3 = encoder_block(encoder2_pool, 256)
# 16

encoder4_pool, encoder4 = encoder_block(encoder3_pool, 512)
# 8

center = conv_block(encoder4_pool, 1024)
# center

decoder4 = decoder_block(center, encoder4, 512)
# 16

decoder3 = decoder_block(decoder4, encoder3, 256)
# 32

decoder2 = decoder_block(decoder3, encoder2, 128)
# 64

decoder1 = decoder_block(decoder2, encoder1, 64)
# 128

decoder0 = decoder_block(decoder1, encoder0, 32)
# 256

outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(decoder0)

## Define your model
Using the functional API, you must define your model by specifying the inputs and outputs associated with the model. 

In [None]:
model = models.Model(inputs=[inputs], outputs=[outputs])

# Defining custom metrics and loss functions
Defining loss and metric functions is simple with Keras. Simply define a function that takes both the true labels for a given example and the predicted labels for the same example. 

The Dice coefficient is a metric that measures the overlap between the predicted and the true mask; the dice loss is one minus that coefficient. More info on optimizing for the Dice coefficient (our dice loss) can be found in the [paper](http://campar.in.tum.de/pub/milletari2016Vnet/milletari2016Vnet.pdf) where it was introduced. 

We use dice loss here because it performs better on class-imbalanced problems by design. In addition, maximizing the dice coefficient and IoU metrics is the actual objective of our segmentation task. Cross entropy is more of a proxy that is easier to optimize; with dice loss we optimize our objective directly. 
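
As a worked example of these overlap measures, the NumPy sketch below computes the Dice coefficient and IoU on two tiny hand-made binary masks. The masks are invented for illustration, and the formulas are the plain, unsmoothed definitions; the `dice_coeff` function in this notebook additionally adds a smoothing term.

```python
import numpy as np

y_true = np.array([[1, 1, 0, 0],
                   [1, 1, 0, 0]], dtype=np.float32)
y_pred = np.array([[1, 0, 0, 0],
                   [1, 1, 1, 0]], dtype=np.float32)

intersection = np.sum(y_true * y_pred)                  # pixels set in both masks: 3
union = np.sum(y_true) + np.sum(y_pred) - intersection  # 4 + 4 - 3 = 5

dice = 2.0 * intersection / (np.sum(y_true) + np.sum(y_pred))  # 6 / 8 = 0.75
iou = intersection / union                                     # 3 / 5 = 0.6

print(dice, iou)
# The two metrics are monotonically related: dice == 2 * iou / (1 + iou)
assert np.isclose(dice, 2 * iou / (1 + iou))
```

Because of that monotonic relation, a prediction that improves one of the two metrics on a given mask also improves the other.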

In [None]:
def dice_coeff(y_true, y_pred):
    smooth = 1.
    # Flatten
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return score

In [None]:
def dice_loss(y_true, y_pred):
    loss = 1 - dice_coeff(y_true, y_pred)
    return loss

Here, we'll use a specialized loss function that combines binary cross entropy and our dice loss. This is based on [individuals who competed within this competition obtaining better results empirically](https://www.kaggle.com/c/carvana-image-masking-challenge/discussion/40199). 
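
To see how the two terms combine, here is a NumPy sketch of the same idea evaluated on toy vectors. It is an illustration under assumptions, not the Keras implementation: `bce_dice_numpy` is a hypothetical name, the clipping constant `eps` is chosen for this example, and the cross entropy is averaged over all pixels.

```python
import numpy as np


def bce_dice_numpy(y_true, y_pred, smooth=1.0, eps=1e-7):
    """Mean binary cross entropy plus smoothed Dice loss for NumPy arrays."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    bce = -np.mean(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    intersection = np.sum(y_true * y_pred)
    dice = (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)
    return bce + (1.0 - dice)


y_true = np.array([1.0, 1.0, 0.0, 0.0])
good = np.array([0.9, 0.8, 0.1, 0.2])  # mostly agrees with y_true
bad = np.array([0.2, 0.1, 0.9, 0.8])   # mostly disagrees with y_true
print(bce_dice_numpy(y_true, good), bce_dice_numpy(y_true, bad))
```

The prediction that agrees with the labels gets the smaller combined loss, and both terms approach zero as the prediction approaches the true mask.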

In [None]:
def bce_dice_loss(y_true, y_pred):
    loss = losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss

## Compile your model
We use our custom loss function to minimize. In addition, we specify what metrics we want to keep track of as we train. Note that metrics are not actually used during the training process to tune the parameters, but are instead used to measure performance of the training process. 

In [None]:
model.compile(optimizer='adam', loss=bce_dice_loss, metrics=[dice_loss])

model.summary()

## Train your model
Training your model with `tf.data` involves simply providing the model's `fit` function with your training/validation dataset, the number of steps, and epochs.  
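
The number of steps per epoch is the number of batches needed to cover the data once. A minimal sketch of that arithmetic follows; `steps_per_epoch` here is a hypothetical helper function and the example counts are made up.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches needed to see every example once."""
    # Rounding up counts the final, smaller batch as a step
    return math.ceil(num_examples / batch_size)


print(steps_per_epoch(10, 3))  # 4 -> batches of 3, 3, 3 and 1
print(steps_per_epoch(9, 3))   # 3
```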

We also include a model callback, [`ModelCheckpoint`](https://keras.io/callbacks/#modelcheckpoint), that saves the model to disk after each epoch. We configure it such that it only saves our highest performing model. Note that saving the model captures more than just the weights of the model: by default, it saves the model architecture, weights, as well as information about the training process such as the state of the optimizer, etc.
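
The "only saves our highest performing model" behaviour boils down to comparing each epoch's monitored value with the best one seen so far. The toy class below sketches that decision in plain Python. It is not the Keras implementation: `BestOnlyTracker` is a hypothetical name and the loss values are invented.

```python
class BestOnlyTracker:
    """Toy stand-in for the 'save only the best model' decision."""

    def __init__(self, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.best = None

    def should_save(self, value):
        """Return True (and remember `value`) if it beats the best seen so far."""
        improved = (
            self.best is None
            or (self.mode == "min" and value < self.best)
            or (self.mode == "max" and value > self.best)
        )
        if improved:
            self.best = value
        return improved


# Hypothetical per-epoch validation losses, for illustration only
val_losses = [0.40, 0.25, 0.30, 0.20]
tracker = BestOnlyTracker(mode="min")
print([tracker.should_save(v) for v in val_losses])  # [True, True, False, True]
```

The direction matters: a monitored quantity that should decrease, such as a loss, calls for `min`, while a score that should increase, such as the dice coefficient, calls for `max`.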

In [None]:
# save_model_path = '/tmp/weights.hdf5'
save_model_path = '/content/seg_unet_carvana_weights.hdf5'
# dice_loss is a loss, so lower is better: keep the checkpoint with the minimum value
cp = tf.keras.callbacks.ModelCheckpoint(filepath=save_model_path, monitor='val_dice_loss', mode='min', save_best_only=True)

Don't forget to specify our model callback in the `fit` function call. 

In [None]:
history = model.fit(train_data_set, 
                   steps_per_epoch=int(np.ceil(num_train_examples / float(batch_size))),
                   epochs=epochs,
                   validation_data=val_data_set,
                   validation_steps=int(np.ceil(num_val_examples / float(batch_size))),
                   callbacks=[cp])

# Visualize training process

In [None]:
dice = history.history['dice_loss']
val_dice = history.history['val_dice_loss']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, dice, label='Training Dice Loss')
plt.plot(epochs_range, val_dice, label='Validation Dice Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Dice Loss')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

Even with only 5 epochs, we see strong performance.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!cp "/content/seg_unet_carvana_weights.hdf5" "/content/drive/MyDrive/"

# Visualize actual performance 
We'll visualize our performance on the validation set.

Note that in an actual setting (competition, deployment, etc.) we'd evaluate on the test set with the full image resolution. 

To load our model we have two options:
1. Since our model architecture is already in memory, we can simply call `model.load_weights(save_model_path)`
2. If you wanted to load the model from scratch (in a different setting without already having the model architecture in memory) we simply call 

```model = models.load_model(save_model_path, custom_objects={'bce_dice_loss': bce_dice_loss, 'dice_coeff': dice_coeff, 'dice_loss': dice_loss})```, specifying the necessary custom objects, loss and metrics, that we used to train our model. 

If you want to see more examples, check out the [Keras guide](https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model)!

In [None]:
save_model_path = "/content/seg_unet_carvana_weights.hdf5"

In [None]:
# Alternatively, load the weights directly: model.load_weights(save_model_path)
model = models.load_model(save_model_path, custom_objects={'bce_dice_loss': bce_dice_loss,
                                                           'dice_coeff': dice_coeff,
                                                           'dice_loss': dice_loss})

In [None]:
plt.figure(figsize=(10, 20))

for i, element in enumerate(val_data_set):
    batch_of_imgs, label = element
    img = batch_of_imgs[0]
    predicted_label = model.predict(batch_of_imgs)[0]
    
    plt.subplot(5, 3, 3 * i + 1)
    plt.imshow(img)
    plt.title("Input image")
    
    plt.subplot(5, 3, 3 * i + 2)
    plt.imshow(label[0, :, :, 0])
    plt.title("Actual Mask")
    
    plt.subplot(5, 3, 3 * i + 3)
    plt.imshow(predicted_label[:, :, 0])
    plt.title("Predicted Mask")

    if i==4:
      break

plt.suptitle("Examples of Input Image, Label, and Prediction")
plt.show()