# U-Net Image Segmentation of HEp-2 Cell Images

## Setup

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras import backend as K


## Dataset

In [None]:
def display(display_list):
    """Show a row of images with fixed captions.

    Panel i is titled with the i-th entry of
    ("Input Image", "True Mask", "Predicted Mask"), so at most three images
    can be passed. Each image is converted with array_to_img and drawn in
    grayscale without axes.
    """
    captions = ["Input Image", "True Mask", "Predicted Mask"]
    panel_count = len(display_list)
    plt.figure(figsize=(12, 6))
    for idx, picture in enumerate(display_list):
        plt.subplot(1, panel_count, idx + 1)
        plt.title(captions[idx])
        plt.imshow(tf.keras.utils.array_to_img(picture), cmap='gray')
        plt.axis("off")
    plt.show()

def get_image_paths(main_path):
    """Collect image and mask BMP paths from the sub-folders of *main_path*.

    Each sub-directory of *main_path* is scanned (one level deep) for files
    ending in ``.bmp``. Files whose name contains ``mask`` go to the mask
    list; all other BMPs go to the image list. Plain files sitting directly
    in *main_path* are ignored.

    Folders and file names are visited in sorted order. ``os.listdir``
    returns entries in arbitrary order, but this notebook pairs
    ``image_paths[i]`` with ``mask_paths[i]`` and splits the lists by
    position, so the order must be deterministic.

    Returns:
        (image_paths, mask_paths): two lists of path strings.
    """
    image_paths, mask_paths = [], []
    for folder in sorted(os.listdir(main_path)):
        folder_path = os.path.join(main_path, folder)
        # Skip stray files (e.g. .DS_Store); listing them would raise.
        if not os.path.isdir(folder_path):
            continue
        for filename in sorted(os.listdir(folder_path)):
            if not filename.endswith('.bmp'):
                continue
            destination = mask_paths if 'mask' in filename else image_paths
            destination.append(os.path.join(folder_path, filename))
    return image_paths, mask_paths

In [None]:
def resize(image, mask):
    """Resize an image and its mask to 128x128 using nearest-neighbour sampling.

    Nearest-neighbour never invents in-between values, so mask labels are
    preserved exactly; the image is resized with the same method.
    """
    target_hw = (128, 128)
    return tuple(
        tf.image.resize(tensor, target_hw, method="nearest")
        for tensor in (image, mask)
    )

def augment(image, mask):
    """Horizontally mirror the image and its mask together with probability 0.5.

    Written with an explicit tf.cond so the same code path runs eagerly and
    inside a traced Dataset.map function.
    """
    def _mirrored():
        return tf.image.flip_left_right(image), tf.image.flip_left_right(mask)

    def _untouched():
        return image, mask

    coin = tf.random.uniform(())
    return tf.cond(coin > 0.5, _mirrored, _untouched)

def normalize(image, mask):
    """Return the image as float32 in [0, 1] and the mask cast to float32.

    tf.image.convert_image_dtype rescales integer images (e.g. uint8 0-255 to
    0-1) but returns float images unchanged. The loaders in this notebook
    already convert the image to float32 in [0, 1] before calling this, so an
    unconditional ``/ 255.0`` here would scale it a second time, into
    [0, 1/255].
    """
    image = tf.image.convert_image_dtype(image, tf.float32)
    mask = tf.cast(mask, tf.float32)
    return image, mask

In [None]:
def _read_image_and_mask(image_path, mask_path):
    """Decode one BMP image/mask pair into float32 tensors in [0, 1].

    The image is decoded as RGB and collapsed to one grayscale channel. The
    mask keeps only its first channel, because the model predicts a single
    channel. For a 1-channel BMP this is a no-op.
    NOTE(review): for multi-channel masks this assumes every channel carries
    the same labels.
    """
    image = tf.io.decode_bmp(tf.io.read_file(image_path), channels=3)
    image = tf.image.rgb_to_grayscale(image)
    image = tf.image.convert_image_dtype(image, "float32")

    # channels=0 keeps whatever channel count is stored in the file.
    mask = tf.io.decode_bmp(tf.io.read_file(mask_path), channels=0)
    mask = mask[..., :1]
    mask = tf.image.convert_image_dtype(mask, "float32")
    return image, mask

def load_image_train(image_path, mask_path):
    """Load, resize, randomly flip and normalize one training pair."""
    image, mask = _read_image_and_mask(image_path, mask_path)
    image, mask = resize(image, mask)
    image, mask = augment(image, mask)
    image, mask = normalize(image, mask)
    return image, mask

def load_image_test(image_path, mask_path):
    """Load, resize and normalize one evaluation pair (no augmentation)."""
    image, mask = _read_image_and_mask(image_path, mask_path)
    image, mask = resize(image, mask)
    image, mask = normalize(image, mask)
    return image, mask

In [None]:
val_samples = 7
main_path = 'MIVIA Lab/Main_Dataset/Images'

image_paths, mask_paths = get_image_paths(main_path)

# Images and masks are paired purely by list position, so unequal counts mean
# the pairing is broken. Fail with a clear message instead of a cryptic
# shape error from from_tensor_slices.
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"Found {len(image_paths)} images but {len(mask_paths)} masks under {main_path!r}"
    )
if not 0 < val_samples < len(image_paths):
    raise ValueError(
        f"val_samples must be between 1 and {len(image_paths) - 1}, got {val_samples}"
    )

# Hold out the last `val_samples` pairs for validation/testing.
split_index = len(image_paths) - val_samples
train_image_paths, train_mask_paths = image_paths[:split_index], mask_paths[:split_index]
test_image_paths, test_mask_paths = image_paths[split_index:], mask_paths[split_index:]

train_dataset = tf.data.Dataset.from_tensor_slices((train_image_paths, train_mask_paths))
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices((test_image_paths, test_mask_paths))
test_dataset = test_dataset.map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
BATCH_SIZE = 4
BUFFER_SIZE = 28

# The training pipeline is deliberately NOT cached. train_dataset applies a
# random flip inside Dataset.map, and caching after that map would replay the
# first epoch's flips in every later epoch. Re-reading the files each epoch
# keeps the augmentation random.
train_batches = (
    train_dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# test_dataset already holds exactly the held-out pairs. The same batches
# serve as validation data during fit() and as the test set for predictions.
validation_batches = test_dataset.batch(BATCH_SIZE)
test_batches = test_dataset.batch(BATCH_SIZE)

In [None]:
# Pull one training batch and show a randomly chosen image/mask pair from it.
sample_batch = next(iter(train_batches))
batch_images, batch_masks = sample_batch
random_index = np.random.choice(batch_images.shape[0])
sample_image = batch_images[random_index]
sample_mask = batch_masks[random_index]
display([sample_image, sample_mask])

## Model Architecture

In [None]:
def double_conv_block(x, n_filters):
    """Apply two 3x3, same-padded, ReLU, He-normal Conv2D layers in sequence."""
    for _ in range(2):
        x = layers.Conv2D(
            n_filters,
            3,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(x)
    return x

In [None]:
def downsample_block(x, n_filters):
    """Encoder stage: double conv, then 2x2 max-pool and 30% dropout.

    Returns:
        (features, pooled): `features` is the pre-pool output, used later as
        the skip connection; `pooled` feeds the next encoder stage.
    """
    features = double_conv_block(x, n_filters)
    pooled = layers.MaxPool2D(2)(features)
    pooled = layers.Dropout(0.3)(pooled)
    return features, pooled

In [None]:
def upsample_block(x, conv_features, n_filters):
    """Decoder stage: 2x transposed-conv upsample, concat skip, dropout, double conv."""
    upsampled = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)
    merged = layers.concatenate([upsampled, conv_features])
    merged = layers.Dropout(0.3)(merged)
    return double_conv_block(merged, n_filters)

## U-Net Model

In [None]:
def build_unet_model(input_shape=(128, 128, 1)):
    """Build a 4-level U-Net producing a per-pixel sigmoid probability map.

    Args:
        input_shape: (height, width, channels) of the input images. The
            default matches this notebook's loaders, which resize images to
            128x128 and convert them to a single grayscale channel
            (rgb_to_grayscale). A 3-channel input layer would reject those
            batches. Height and width must be divisible by 16 (four 2x
            poolings) so the skip connections line up.

    Returns:
        An uncompiled tf.keras.Model named "U-Net" with a (H, W, 1) output.
    """
    inputs = layers.Input(shape=input_shape)
    # Encoder: each stage halves spatial size and keeps features for a skip.
    f1, p1 = downsample_block(inputs, 64)
    f2, p2 = downsample_block(p1, 128)
    f3, p3 = downsample_block(p2, 256)
    f4, p4 = downsample_block(p3, 512)
    bottleneck = double_conv_block(p4, 1024)
    # Decoder: upsample and merge with the matching encoder features.
    u6 = upsample_block(bottleneck, f4, 512)
    u7 = upsample_block(u6, f3, 256)
    u8 = upsample_block(u7, f2, 128)
    u9 = upsample_block(u8, f1, 64)
    outputs = layers.Conv2D(1, 1, padding="same", activation="sigmoid")(u9)
    unet_model = tf.keras.Model(inputs, outputs, name="U-Net")
    return unet_model

In [None]:
# Build the network once; later cells compile, train and predict with it.
unet_model = build_unet_model()
unet_model.summary()

## Compile and Train U-Net

In [None]:
def f1_score(y_true, y_pred):
    """Per-batch F1 score of rounded predictions, for use as a Keras metric.

    Products, targets and predictions are clipped to [0, 1] and rounded to
    the nearest integer, which effectively thresholds at 0.5. Keras averages
    this per-batch value over an epoch, so the reported figure approximates
    rather than equals the dataset-level F1.

    Written with core TensorFlow ops. The ``keras.backend`` helpers
    sum/round/clip used previously were removed in Keras 3.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    true_positives = tf.reduce_sum(tf.round(tf.clip_by_value(y_true * y_pred, 0, 1)))
    possible_positives = tf.reduce_sum(tf.round(tf.clip_by_value(y_true, 0, 1)))
    predicted_positives = tf.reduce_sum(tf.round(tf.clip_by_value(y_pred, 0, 1)))
    eps = K.epsilon()
    precision = true_positives / (predicted_positives + eps)
    recall = true_positives / (possible_positives + eps)
    f1_val = 2 * (precision * recall) / (precision + recall + eps)
    return f1_val

def dice_loss(targets, inputs, smooth=1e-6):
    """Soft Dice loss, ``1 - Dice``, computed over every pixel in the batch.

    Args:
        targets: ground-truth masks (y_true); cast to the prediction dtype.
        inputs: predicted probabilities (y_pred).
        smooth: added to numerator and denominator so empty masks give a
            Dice of 1 instead of 0/0.

    Uses core TensorFlow ops; ``keras.backend.flatten``/``sum`` were removed
    in Keras 3.
    """
    inputs = tf.reshape(inputs, [-1])
    targets = tf.reshape(tf.cast(targets, inputs.dtype), [-1])
    intersection = tf.reduce_sum(inputs * targets)
    dice = (2 * intersection + smooth) / (
        tf.reduce_sum(targets) + tf.reduce_sum(inputs) + smooth
    )
    return 1 - dice

In [None]:
# Plain SGD on the Dice loss; the remaining metrics are tracked for plotting.
sgd_optimizer = tf.keras.optimizers.SGD()
tracked_metrics = [
    'accuracy',
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.BinaryIoU(name='iou'),
    f1_score,
]
unet_model.compile(
    optimizer=sgd_optimizer,
    loss=dice_loss,
    metrics=tracked_metrics,
)

In [None]:
import math

NUM_EPOCHS = 10

TRAIN_LENGTH = len(train_dataset)
# train_batches repeats forever, so this step count defines an "epoch".
# Keep it at least 1 so a dataset smaller than one batch still trains.
STEPS_PER_EPOCH = max(1, TRAIN_LENGTH // BATCH_SIZE)

VAL_SUBSPLITS = 1
TEST_LENGTH = len(test_dataset)
# validation_batches is finite, so round UP to include the last partial
# batch. Flooring drops it; e.g. 7 samples // 4 evaluates only 4 of them.
VALIDATION_STEPS = max(1, math.ceil(TEST_LENGTH / BATCH_SIZE / VAL_SUBSPLITS))

model_history = unet_model.fit(train_batches,
                               epochs=NUM_EPOCHS,
                               steps_per_epoch=STEPS_PER_EPOCH,
                               validation_steps=VALIDATION_STEPS,
                               validation_data=validation_batches)

## Learning Curves and Predictions

In [None]:
def display_learning_curves(history):
    """Plot training vs. validation curves for each compiled metric.

    Draws a 3x2 grid (loss, accuracy, F1, Binary IoU, precision, recall).
    Each panel shows the training series and its ``val_`` counterpart per
    epoch.

    Args:
        history: the History object returned by Model.fit. Its ``history``
            dict must contain every key below plus its ``val_`` variant,
            i.e. fit was run with validation data.
    """
    curves = history.history
    # (history key, panel title / y-label, legend name, legend location)
    panels = [
        ("loss", "Loss", "loss", "upper left"),
        ("accuracy", "Accuracy", "accuracy", "upper right"),
        ("f1_score", "F1 Score", "f1_score", "lower left"),
        ("iou", "Binary IoU", "Binary IoU", "lower right"),
        ("precision", "Precision", "precision", "upper right"),
        ("recall", "Recall", "recall", "upper right"),
    ]
    # Take the epoch count from the recorded history, not a global constant,
    # so the plot is still correct if training stopped early.
    epochs_range = range(len(curves["loss"]))

    fig = plt.figure(figsize=(12, 6))
    for position, (key, title, legend_name, legend_loc) in enumerate(panels, start=1):
        plt.subplot(3, 2, position)
        plt.plot(epochs_range, curves[key], label=f"train {legend_name}")
        # Every panel, IoU included, reads its matching "val_" series.
        plt.plot(epochs_range, curves[f"val_{key}"], label=f"validation {legend_name}")
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(title)
        plt.legend(loc=legend_loc)

    fig.tight_layout()
    plt.show()

display_learning_curves(model_history)

In [None]:
def show_predictions(dataset, num):
    """Plot input, ground truth and thresholded prediction for `num` batches.

    Only the first example of each batch is displayed. Uses the global
    `unet_model`.
    """
    for image, mask in dataset.take(num):
        probabilities = unet_model.predict(image)
        # Binarise at 0.5 and keep the result as float 0/1. array_to_img
        # (inside display) rescales to 0-255 itself. Multiplying a uint8
        # tensor by the Python float 255.0 raises a TypeError in TensorFlow.
        pred_mask = tf.cast(probabilities > 0.5, tf.float32)
        display([image[0], mask[0], pred_mask[0]])

show_predictions(test_batches, 1)