<a href="https://colab.research.google.com/github/Lorenzo-B/chaos-biocv/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Libraries

In [None]:
# pydicom is not imported anywhere in this notebook; presumably the chaos_dataset
# TFDS builder needs it to read the DICOM slices during `tfds build` — confirm.
# NOTE(review): consider pinning a version (`%pip install -q pydicom==X.Y.Z`) for reproducibility.
%pip install pydicom

In [None]:
import tensorflow as tf
import keras as keras
import tensorflow_datasets as tfds
from keras.callbacks import ModelCheckpoint

import matplotlib.pyplot as plt
import numpy as np

from pathlib import Path

# Clone Dataset

In [None]:
!git clone https://github.com/Lorenzo-B/chaos-biocv.git

In [None]:
%cd chaos-biocv/chaos_dataset/
!tfds build
%cd /content/

In [None]:
# Load the locally built CHAOS dataset (pinned builder version) together with its metadata.
dataset, info = tfds.load('chaos_dataset:1.0.12', with_info=True)

BATCH_SIZE = 32
TRAIN_LENGTH = info.splits['train'].num_examples
# Shuffle buffer spans the whole training split (the split is cached before shuffling).
BUFFER_SIZE = TRAIN_LENGTH
# Number of full training batches in one pass over the training split.
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

# Load Dataset

## Preprocessing and data augmentation

In [None]:
def resize(input_image: tf.Tensor, input_mask: tf.Tensor):
    """Resize an image/mask pair to 128x128.

    The image uses `tf.image.resize`'s default (bilinear) resampling. The mask uses
    nearest-neighbour resampling so that label values are never interpolated.

    Returns:
        Tuple ``(resized_image, resized_mask)``.
    """
    target_size = (128, 128)
    resized_image = tf.image.resize(input_image, target_size)
    resized_mask = tf.image.resize(
        input_mask,
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    return resized_image, resized_mask

def normalize(input_image: tf.Tensor, input_mask: tf.Tensor):
    """Min-max scale one image's intensities into [0, 1]; pass the mask through.

    Each example is scaled independently: ``(x - min(x)) / (max(x) - min(x))``.
    This keeps the network input in a small, consistent range whatever intensity
    range the dataset stores. The range is presumably raw DICOM values; confirm
    against the chaos_dataset builder.

    The mask is returned unchanged. Its values are class indices (used as labels by
    SparseCategoricalCrossentropy and as `tf.gather` indices for the sample weights),
    so they must not be rescaled.

    Returns:
        Tuple ``(float32 image in [0, 1], mask)``.
    """
    input_image = tf.cast(input_image, tf.float32)
    min_val = tf.reduce_min(input_image)
    max_val = tf.reduce_max(input_image)
    # divide_no_nan returns 0 for constant (e.g. blank) slices instead of NaN/Inf.
    input_image = tf.math.divide_no_nan(input_image - min_val, max_val - min_val)
    return input_image, input_mask

class Augment(tf.keras.layers.Layer):
    """Apply a random vertical flip to an image batch and, in step, to its mask batch.

    Two separate RandomFlip layers are created with the same seed. They therefore
    make the same random flip decisions, which keeps every image aligned with its mask.
    """

    def __init__(self, seed=42):
        super().__init__()
        flip_kwargs = dict(mode="vertical", seed=seed)
        self.augment_inputs = keras.layers.RandomFlip(**flip_kwargs)
        self.augment_masks = keras.layers.RandomFlip(**flip_kwargs)

    def call(self, inputs, labels):
        # Images first, then masks, so each flip layer is advanced exactly once per call.
        return self.augment_inputs(inputs), self.augment_masks(labels)


def load_image(datapoint):
    """Turn one TFDS example dict into a preprocessed ``(image, mask)`` pair.

    Reads the ``"image"`` and ``"segmentation_mask"`` features, resizes both to
    128x128 and then applies `normalize`.
    """
    image = datapoint["image"]
    mask = datapoint["segmentation_mask"]
    image, mask = resize(image, mask)
    return normalize(image, mask)


## Prepare dataset

In [None]:
# Apply the same per-example preprocessing (decode -> resize -> normalize) to every split.
train_dataset, validation_dataset, test_dataset = (
    dataset[split].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    for split in ("train", "validation", "test")
)

In [None]:
# Training pipeline:
#   - cache the preprocessed examples;
#   - reshuffle them on every pass;
#   - batch, then repeat indefinitely (fit() bounds each epoch via steps_per_epoch);
#   - augment whole batches;
#   - prefetch so input processing overlaps with training.
train_batches = (
    train_dataset
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    # tf.data.AUTOTUNE replaces the deprecated tf.data.experimental.AUTOTUNE alias.
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# Evaluation pipelines use no shuffling or augmentation, but are prefetched as well.
validation_batches = validation_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_batches = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Helper functions

In [None]:
def display(display_list):
    """Show up to three images side by side in one 15x15-inch figure.

    Panels are titled, in order, 'Input Image', 'True Mask', 'Predicted Mask'.
    Each tensor is converted with `tf.keras.utils.array_to_img` before plotting.
    Note that this name shadows IPython's built-in `display`.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)

    plt.figure(figsize=(15, 15))
    for idx, item in enumerate(display_list):
        plt.subplot(1, n_panels, idx + 1)
        plt.title(titles[idx])
        plt.imshow(tf.keras.utils.array_to_img(item))
        plt.axis('off')
    plt.show()

In [None]:
# Preview the first example of two augmented training batches. The last pair shown
# stays in `sample_image` / `sample_mask`, which show_predictions() uses when called
# without a dataset.
for images, masks in train_batches.take(2):
    sample_image = images[0]
    sample_mask = masks[0]
    display([sample_image, sample_mask])

# U-Net Model

In [None]:
def double_conv_block(x, n_filters):
    """Apply two stacked 3x3 Conv2D + ReLU layers with `n_filters` filters each.

    Both convolutions use 'same' padding (spatial size preserved) and He-normal
    kernel initialization.
    """
    for _ in range(2):
        x = keras.layers.Conv2D(
            n_filters,
            3,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
        )(x)
    return x

def downsample_block(x, n_filters):
    """Encoder step: double convolution, then 2x2 max-pooling and 30% dropout.

    Returns:
        Tuple ``(features, pooled)``. ``features`` is the pre-pooling activation,
        used as the skip connection; ``pooled`` is the input to the next encoder level.
    """
    features = double_conv_block(x, n_filters)
    pooled = keras.layers.MaxPool2D(2)(features)
    pooled = keras.layers.Dropout(0.3)(pooled)
    return features, pooled

def upsample_block(x, conv_features, n_filters):
    """Decoder step: upsample 2x, merge the skip connection, then refine.

    The steps are: a 3x3 stride-2 transposed convolution, concatenation with the
    matching encoder features `conv_features`, 30% dropout, and a double convolution.
    """
    upsampled = keras.layers.Conv2DTranspose(n_filters, 3, strides=2, padding="same")(x)
    merged = keras.layers.concatenate([upsampled, conv_features])
    merged = keras.layers.Dropout(0.3)(merged)
    return double_conv_block(merged, n_filters)

In [None]:
def build_unet_model(input_shape=(128, 128, 1), num_classes=5):
    """Build a 4-level U-Net for per-pixel multi-class segmentation.

    Args:
        input_shape: ``(height, width, channels)`` of the input images. Height and
            width should be divisible by 16: the encoder halves them four times, and
            the decoder's skip concatenations need matching spatial sizes.
        num_classes: Number of output classes, i.e. softmax channels.

    Returns:
        An uncompiled ``tf.keras.Model`` named "U-Net". It maps images to per-pixel
        class probabilities of shape ``(height, width, num_classes)``.
    """
    inputs = keras.layers.Input(shape=input_shape)

    # Encoder (contracting path). Keep each pre-pooling feature map for the skip connections.
    f1, p1 = downsample_block(inputs, 64)
    f2, p2 = downsample_block(p1, 128)
    f3, p3 = downsample_block(p2, 256)
    f4, p4 = downsample_block(p3, 512)

    # Bottleneck.
    bottleneck = double_conv_block(p4, 1024)

    # Decoder (expanding path). Upsample and merge with the matching encoder features.
    u6 = upsample_block(bottleneck, f4, 512)
    u7 = upsample_block(u6, f3, 256)
    u8 = upsample_block(u7, f2, 128)
    u9 = upsample_block(u8, f1, 64)

    # A 1x1 convolution with softmax yields per-pixel class probabilities, which is
    # why the loss is compiled with from_logits=False.
    outputs = keras.layers.Conv2D(
        filters=num_classes, kernel_size=1, padding="same", activation="softmax"
    )(u9)

    return tf.keras.Model(inputs, outputs, name="U-Net")

In [None]:
# Build the U-Net defined above (128x128x1 input, 5-class softmax output).
unet_model = build_unet_model()

In [None]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
unet_model.summary()

In [None]:
# The model's last layer already applies softmax, so the loss receives
# probabilities rather than logits (from_logits=False).
unet_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

In [None]:
def create_mask(pred_mask):
    """Convert a batch of per-class predictions into the label map of its first example.

    Args:
        pred_mask: Model output with classes on the last axis, e.g.
            ``(batch, height, width, classes)``.

    Returns:
        A ``(height, width, 1)`` tensor holding the argmax class index per pixel.
    """
    label_maps = tf.expand_dims(tf.math.argmax(pred_mask, axis=-1), axis=-1)
    return label_maps[0]

In [None]:
def show_predictions(dataset=None, num=1, model=None):
    """Plot input image, ground-truth mask and predicted mask side by side.

    Args:
        dataset: Optional batched dataset of ``(image, mask)`` pairs. When given, the
            first example of each of the first `num` batches is shown. When omitted,
            the global ``sample_image`` / ``sample_mask`` pair is used.
        num: Number of batches to take from `dataset`. Ignored without a dataset.
        model: Model used for prediction. Defaults to the global ``unet_model``.
            Pass a restored model, e.g. from a checkpoint, to visualize its predictions.
    """
    if model is None:
        model = unet_model

    if dataset is not None:
        for image, mask in dataset.take(num):
            # Only the first example of each batch is displayed, so predict just that one.
            pred_mask = model.predict(image[:1])
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        display([sample_image, sample_mask,
                 create_mask(model.predict(sample_image[tf.newaxis, ...]))])

## Predictions (before training)

In [None]:
show_predictions()

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots a sample prediction at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        show_predictions()
        # `epoch` is 0-based; report it 1-based.
        print(f'\nSample Prediction after epoch {epoch + 1}\n')

In [None]:
def add_sample_weights(image, label):
    """Attach a per-pixel sample-weight map derived from each pixel's class label.

    The class weights ``[1, 10, 5, 6, 6]`` (indexed by class id 0..4) are divided by
    their sum so that they add up to 1.0. Each label pixel is then replaced by the
    weight of its class.

    Returns:
        Tuple ``(image, label, sample_weights)``, where ``sample_weights`` has the
        label's shape.
    """
    raw_weights = tf.constant([1.0, 10.0, 5.0, 6.0, 6.0])
    normalized_weights = raw_weights / tf.reduce_sum(raw_weights)

    # Use the (integer) label of every pixel as an index into the weight table.
    pixel_weights = tf.gather(normalized_weights, indices=tf.cast(label, tf.int32))
    return image, label, pixel_weights

## Training

In [None]:
EPOCHS = 20
VAL_SUBSPLITS = 5
# Validate on about 1/VAL_SUBSPLITS of the validation split each epoch. Use at least
# one batch: a small split would otherwise give 0 steps, and then no `val_loss` would
# exist for the checkpoint or for the loss plot.
VALIDATION_STEPS = max(
    1, info.splits['validation'].num_examples // BATCH_SIZE // VAL_SUBSPLITS
)

# Keep the checkpoint with the lowest *validation* loss. Selecting on training loss
# tends to favour the latest, potentially overfit, epochs.
model_checkpoint = ModelCheckpoint(
    'chaos19_unet.keras', monitor='val_loss', verbose=1, save_best_only=True
)
model_history = unet_model.fit(
    train_batches.map(add_sample_weights),
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    validation_data=validation_batches,
    callbacks=[model_checkpoint, DisplayCallback()],
)

# Training and validation loss

In [None]:
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

fig, ax = plt.subplots()
ax.plot(model_history.epoch, loss, 'r', label='Training loss')
ax.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
ax.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss Value')
# Anchor the y-axis at 0 but let its top scale to the data. The validation batches
# carry no sample weights, so their loss is plain cross-entropy and can exceed 1
# (about ln(5) = 1.61 for a uniform 5-class guess). A fixed [0, 1] range would clip it.
ax.set_ylim(bottom=0)
ax.legend()
plt.show()

# Predictions

In [None]:
# Rebuild the architecture and restore the checkpoint written by ModelCheckpoint.
model = build_unet_model()
model.load_weights("chaos19_unet.keras")

# show_predictions() predicts with the global `unet_model`. Point that name at the
# restored model so the test-set predictions come from the saved checkpoint rather
# than from the final-epoch weights.
unet_model = model
show_predictions(test_batches, 20)