# U-Net Workshop

Here is an example 2D U-Net implementation using minified data based upon a Creative Commons dataset available at https://wiki.cancerimagingarchive.net/display/Public/HNSCC.

In [None]:
# Copyright 2021 Radiotherapy AI Pty Ltd

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#    http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Overview

* Library imports and namespaces
* Utilising glob and pathlib
* Plotting with matplotlib
* Building a 2D UNet
* Setting up a callback
* Training the model and viewing the results

### Some programming principles

* Prototype to learn
* Don't repeat yourself
* Don't assume it, prove it
* Design self-contained independent well-defined reusable components


### Resources

* [DRY -- "Every piece of knowledge must have a single, unambiguous, authoritative representation within a system"](http://media.pragprog.com/titles/tpp20/dry.pdf)
* [NumPy for MATLAB users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html)
* [PyTest primer](https://www.tutorialspoint.com/pytest/pytest_quick_guide.htm)

In [None]:
# First things first, describe the Google Colab interface
# Make sure everyone can run a hello world.
# Swap to CPU for now

## Library imports

Here is a set of library imports, from both the standard library and some libraries downloadable from PyPI. These are imported within namespaces so as to avoid variable and function name conflicts.

In [None]:
import this

In [None]:
import pathlib
import random
import shutil
import urllib.request

import imageio
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [None]:
# Show the effects of importing math with and without namespacing

## Constants

In [None]:
# Number of image/mask pairs grouped into each batch by the data pipeline
BATCH_SIZE = 512
# Side length of the square image grid; used for both height and width
GRID_SIZE = 64

# Names of the dataset sub-directories expected inside the data directory
DATASET_TYPES = {"hold-out", "training", "validation"}

# One (plot colour, structure label) pair per mask channel, in channel order
COLOURS_AND_LABELS = [
    ("#ff7f0e", "left parotid"),
    ("#2ca02c", "right parotid"),
    ("#d62728", "external"),
]
# Single authoritative representation of knowledge
NUM_CONTOURS = len(COLOURS_AND_LABELS)

In [None]:
# Shape of a single image (one channel) and of its masks (one channel
# per contour), both on the same GRID_SIZE x GRID_SIZE grid
IMAGE_DIMENSIONS = (GRID_SIZE, GRID_SIZE, 1)
MASK_DIMENSIONS = (GRID_SIZE, GRID_SIZE, NUM_CONTOURS)

In [None]:
# Overview the set, list, and tuple types
# The len function
# Iteration over a list

In [None]:
# Shapes a full batch is expected to have: the batch axis first, followed
# by the per-item dimensions unpacked in place with `*`
EXPECTED_BATCH_IMAGE_DIMENSIONS = (
    BATCH_SIZE,
    *IMAGE_DIMENSIONS,
)
EXPECTED_BATCH_MASK_DIMENSIONS = (
    BATCH_SIZE,
    *MASK_DIMENSIONS,
)

In [None]:
# Demonstrate the effect of tuple unpacking

## Download and investigate the data

In [None]:
# Prototype data, prototype to learn

# Adjacent string literals are joined by Python into a single URL string
zip_url = (
    "https://github.com/RadiotherapyAI/"
    "unet-workshop/releases/download/"
    "mini-parotid/mini-parotid.zip"
)
# Last expression of the cell, so the assembled URL is displayed
zip_url

In [None]:
# Investigate the downloadable data within a filebrowser

In [None]:
# Local file name the archive is downloaded to
zip_filepath = "data.zip"

# Directory the archive is unpacked into
data_directory = pathlib.Path("data")

# Only download and unpack when the data directory is not already present,
# so re-running this cell does not fetch the archive again
if not data_directory.exists():
    urllib.request.urlretrieve(zip_url, zip_filepath)
    shutil.unpack_archive(zip_filepath, data_directory)

In [None]:
# Investigate the downloaded data with pathlib and glob
# Load an image with imageio and create a plot with
# matplotlib's imshow

## After demo -- BREAKOUT ROOMS -- gather feedback on tutorial's pace ##
# -- GOAL: Have everyone able to plot
#          the downloaded data, demoed first

In [None]:
# Names of the directories found directly inside the data directory
dataset_types_found = {
    path.name
    for path in data_directory.glob("*")
    if path.is_dir()
}

# Don't assume it, prove it
# The unpacked data must contain exactly the expected dataset directories
assert dataset_types_found == DATASET_TYPES

## Build the TensorFlow pipeline

In [None]:
# Design re-usable components


def get_image_paths(dataset_type):
    """Collect the image files belonging to one dataset type.

    Parameters
    ----------
    dataset_type
        Name of a sub-directory of ``data_directory``.

    Returns
    -------
    list of pathlib.Path
        Every file matching ``*/*.image.png`` beneath that sub-directory,
        i.e. the images sitting one directory level below it.
    """
    dataset_directory = data_directory / dataset_type
    matches = dataset_directory.glob("*/*.image.png")

    return list(matches)

In [None]:
# Explain the components of a function definition, inputs/outputs
# Demo the usage of this function
# Find the number of image paths for each dataset type

In [None]:
def get_path_pairs(dataset_type):
    """Pair every image file of a dataset type with its masks file.

    The masks file is taken to sit next to the image and to share its
    name, with the ``.image.png`` suffix swapped for ``.masks.png``.

    Parameters
    ----------
    dataset_type
        Name of the dataset type, passed on to ``get_image_paths``.

    Returns
    -------
    list of (str, str)
        ``(image path, masks path)`` tuples, sorted in ascending order.
    """
    image_suffix = ".image.png"

    path_pairs = []
    for image_path in get_image_paths(dataset_type):
        name = image_path.name
        if name.endswith(image_suffix):
            # Strip the known suffix rather than cutting at the first ".",
            # so identifiers that themselves contain dots are kept whole.
            identifier = name[: -len(image_suffix)]
        else:
            # Unexpected file name: fall back to the text before the
            # first "."
            identifier = name.split(".")[0]

        masks_path = image_path.parent / f"{identifier}.masks.png"
        path_pairs.append((str(image_path), str(masks_path)))

    return sorted(path_pairs)

In [None]:
@tf.function
def load(path_pair):
    """Read one image file and its masks file, scaled by 1/255.

    Parameters
    ----------
    path_pair
        Indexable pair whose item 0 is the image file path and whose
        item 1 is the masks file path.

    Returns
    -------
    tuple
        ``(image, masks)``. The image is decoded with a single channel
        and the masks with ``NUM_CONTOURS`` channels, both as ``uint8``,
        and each is then divided by 255.
    """
    image = tf.io.decode_image(
        tf.io.read_file(path_pair[0]),
        channels=1,
        dtype=tf.uint8,
    )
    masks = tf.io.decode_image(
        tf.io.read_file(path_pair[1]),
        channels=NUM_CONTOURS,
        dtype=tf.uint8,
    )

    return image / 255, masks / 255

In [None]:
def create_datasets(dataset_type):
    """Build the batched ``tf.data`` pipeline for one dataset type.

    The (image path, masks path) pairs are shuffled across the whole
    dataset (reshuffled on every iteration), loaded with ``load``,
    grouped into batches of ``BATCH_SIZE`` and prefetched.

    Parameters
    ----------
    dataset_type
        Name of the dataset type, passed on to ``get_path_pairs``.

    Returns
    -------
    tf.data.Dataset
        Dataset yielding ``(images, masks)`` batches.
    """
    path_pairs = get_path_pairs(dataset_type)

    return (
        tf.data.Dataset.from_tensor_slices(path_pairs)
        # A buffer as large as the dataset shuffles over every pair
        .shuffle(len(path_pairs), reshuffle_each_iteration=True)
        .map(load)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

In [None]:
# Build one batched pipeline per dataset type, keyed by the type's name
datasets = {}

for dataset_type in DATASET_TYPES:
    datasets[dataset_type] = create_datasets(dataset_type)

# Last expression of the cell, so the resulting dictionary is displayed
datasets

In [None]:
# A description on dictionaries, and how they can be used
# to apply a single task to many variables

In [None]:
## BREAKOUT ROOMS ##
# -- Goal: Use a for loop to build a dictionary
#          containing the length of each dataset

In [None]:
# Don't assume it, prove it

# Every batch is compared against the full-batch shape. The pipeline calls
# `batch` without `drop_remainder`, so a batch holding fewer than
# BATCH_SIZE items makes these assertions fail.
for batched_images, batched_masks in datasets["training"]:
    # Include this after seeing the assertion error
    # print(batched_images.shape)
    assert (
        batched_images.shape
        == EXPECTED_BATCH_IMAGE_DIMENSIONS
    )
    assert (
        batched_masks.shape
        == EXPECTED_BATCH_MASK_DIMENSIONS
    )

# After use, comment out this whole cell with
# `Ctrl + /` so that "Run All Cells" doesn't
# error out anymore.

In [None]:
# Why did this assertion fail?

In [None]:
def dataset_dimensions_check(dataset_type):
    """Assert that every batch of a dataset has the expected shape.

    All batches but the last must match the full-batch shapes. The last
    batch must hold whatever is left over once the full batches have been
    taken, which is a full batch when the image count divides evenly by
    ``BATCH_SIZE``.

    Parameters
    ----------
    dataset_type
        Name of the dataset type to build and check.

    Raises
    ------
    AssertionError
        If the dataset type has no images, or any batch has an
        unexpected shape.
    """
    number_of_images = len(get_image_paths(dataset_type))
    assert number_of_images > 0

    dataset = create_datasets(dataset_type)

    image_shapes = []
    mask_shapes = []

    for batched_images, batched_masks in dataset:
        image_shapes.append(batched_images.shape)
        mask_shapes.append(batched_masks.shape)

    for image_shape, mask_shape in zip(
        image_shapes[:-1], mask_shapes[:-1]
    ):
        assert (
            image_shape == EXPECTED_BATCH_IMAGE_DIMENSIONS
        )
        assert mask_shape == EXPECTED_BATCH_MASK_DIMENSIONS

    # A remainder of 0 means the images divide evenly into batches, in
    # which case the final batch is a full one rather than an empty one.
    final_batch_size = (
        number_of_images % BATCH_SIZE or BATCH_SIZE
    )

    assert image_shapes[-1] == (
        final_batch_size,
        *IMAGE_DIMENSIONS,
    )
    assert mask_shapes[-1] == (
        final_batch_size,
        *MASK_DIMENSIONS,
    )


# Can move a function like this into its own Python file
# named `test_dataset.py` to allow it to be picked up by
# the automated testing framework `PyTest`.
# https://www.tutorialspoint.com/pytest/pytest_quick_guide.htm


def test_dataset_dimensions():
    """Run the batch-shape check against every known dataset type.

    The name of each dataset type is printed before it is checked, so the
    last name printed identifies the dataset behind a failing assertion.
    """
    for name in DATASET_TYPES:
        print(name)
        dataset_dimensions_check(name)


# Need to 'test your tests'
# Run the check straight away; an AssertionError here stops the cell
test_dataset_dimensions()

# Verify that this test appropriately fails if items above are changed

## BREAKOUT ROOMS ##
# -- Goal: Everyone able to break and then repair the above test

## Investigate the created pipeline

In [None]:
# Pull a single batch of images and masks out of the validation dataset.
# The builtin `next` is the Python 3 way to advance an iterator; calling a
# `.next()` method relies on a Python 2 style compatibility shim.
batch_validation_images, batch_validation_masks = next(
    iter(datasets["validation"].take(1))
)

In [None]:
# Investigate the shapes of the batched images and masks

In [None]:
# Take the first item of the batch, which drops the leading batch axis
image = batch_validation_images[0, ...]
masks = batch_validation_masks[0, ...]

In [None]:
# Investigate the shapes of the newly indexed objects

In [None]:
# Use imshow to view these images and masks

## Create a useful representation of the data

In [None]:
def plot_contours(ax, image, masks):
    """Show an image in greyscale with one outline per mask channel.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    image
        Array indexed as ``image[:, :, 0]`` to obtain the 2D picture.
    masks
        Array whose last axis holds one mask per entry of
        ``COLOURS_AND_LABELS``, in the same order.

    Notes
    -----
    Each mask is outlined at the 0.5 level in its configured colour.
    Masks lying entirely below or entirely above 0.5 have no such outline
    and are skipped. The legend lists only the outlines that were drawn
    and is left out altogether when there are none.
    """
    ax.imshow(image[:, :, 0], cmap="gray")

    legend_handles = []
    for i, (colour, label) in enumerate(COLOURS_AND_LABELS):
        mask = masks[..., i]
        if np.all(mask < 0.5) or np.all(mask > 0.5):
            continue

        ax.contour(
            mask,
            colors=[colour],
            levels=[0.5],
        )
        # Use a stand-alone line as the legend entry instead of reaching
        # into the contour set's `collections`, an attribute that newer
        # Matplotlib releases no longer provide.
        legend_handles.append(
            plt.Line2D([], [], color=colour, label=label)
        )

    ax.axis("equal")
    if legend_handles:
        ax.legend(handles=legend_handles)

In [None]:
# A single square axes for drawing the image and its contours
fig, ax = plt.subplots(figsize=(6, 6))
# Use the plot_contours function

## BREAKOUT ROOMS ##
# -- Goal: Everyone able to investigate the validation images and masks

## Building the 2D U-Net model

In this section we will create a TensorFlow Keras 2D U-Net model utilising a set of pre-built functions. An example U-Net diagram is given below to aid explanation:

![](https://github.com/RadiotherapyAI/unet-workshop/blob/019f25013030e51b83e2370b347bf5933aebc37c/images/unet.png?raw=1)

In [None]:
def activation(x):
    """Pass ``x`` through a new Keras ReLU activation layer."""
    return tf.keras.layers.Activation("relu")(x)


def convolution(x, number_of_filters, kernel_size=3):
    """Pass ``x`` through a new 2D convolution layer.

    The layer uses ``"same"`` padding and ``he_normal`` kernel
    initialisation.

    Parameters
    ----------
    x
        Input to the layer.
    number_of_filters
        Number of filters of the convolution.
    kernel_size
        Size of the convolution kernel, 3 by default.
    """
    layer = tf.keras.layers.Conv2D(
        filters=number_of_filters,
        kernel_size=kernel_size,
        padding="same",
        kernel_initializer="he_normal",
    )

    return layer(x)


def conv_transpose(x, number_of_filters, kernel_size=3):
    """Pass ``x`` through a new stride-2 transposed 2D convolution layer.

    The layer uses ``"same"`` padding and ``he_normal`` kernel
    initialisation.

    Parameters
    ----------
    x
        Input to the layer.
    number_of_filters
        Number of filters of the transposed convolution.
    kernel_size
        Size of the convolution kernel, 3 by default.
    """
    layer = tf.keras.layers.Conv2DTranspose(
        filters=number_of_filters,
        kernel_size=kernel_size,
        strides=2,
        padding="same",
        kernel_initializer="he_normal",
    )

    return layer(x)

In [None]:
# Highlight where the activation, convolution, and conv_transpose occurs in the UNet diagram

In [None]:
def encode(
    x,
    number_of_filters,
    number_of_convolutions=2,
):
    """One encoding (downward) step of the 2D UNet.

    Runs ``number_of_convolutions`` rounds of convolution followed by
    activation, keeps that result for the skip connection, then applies
    max pooling (Keras default settings) and one further activation.

    Returns
    -------
    tuple
        ``(pooled, skip)``: the pooled output that continues down the
        network, and the pre-pooling features for the matching decode
        step.
    """
    features = x
    for _ in range(number_of_convolutions):
        features = activation(
            convolution(features, number_of_filters)
        )

    pooled = tf.keras.layers.MaxPool2D()(features)
    pooled = activation(pooled)

    return pooled, features


def decode(
    x,
    skip,
    number_of_filters,
    number_of_convolutions=2,
):
    """One decoding (upward) step of the 2D UNet.

    Upsamples ``x`` with a transposed convolution and an activation,
    joins it to ``skip`` along the last axis (skip first), then runs
    ``number_of_convolutions`` rounds of convolution followed by
    activation.

    Returns the output of the final activation.
    """
    upsampled = activation(
        conv_transpose(x, number_of_filters)
    )
    merged = tf.keras.layers.concatenate(
        [skip, upsampled], axis=-1
    )

    for _ in range(number_of_convolutions):
        merged = activation(
            convolution(merged, number_of_filters)
        )

    return merged

In [None]:
# Highlight the encode and decode sections of the UNet

In [None]:
def get_unet_filter_counts(grid_size):
    """Work out the per-level convolution filter counts for a UNet.

    The number of levels is ``int(log2(grid_size / 8))``. Encoding counts
    double at each level starting from 32; decoding counts start at twice
    the last encoding count and halve at each level, ending at 64.

    Parameters
    ----------
    grid_size
        Side length of the square input grid.

    Returns
    -------
    tuple of numpy.ndarray
        ``(encoding_filter_counts, decoding_filter_counts)``, e.g.
        ``([32, 64, 128], [256, 128, 64])`` for a grid size of 64.
    """
    network_depth = int(np.log2(grid_size / 8))
    levels = np.array(range(network_depth))

    encoding_filter_counts = 2 ** (levels + 5)
    # Same progression one power of two higher, walked deepest level first
    decoding_exponents = (levels + 6)[::-1]
    decoding_filter_counts = 2 ** decoding_exponents

    return encoding_filter_counts, decoding_filter_counts

In [None]:
# Show the effect of a range of different grid_sizes

In [None]:
def unet(grid_size, num_contours):
    """Assemble a bare-bones 2D UNet as a Keras model.

    Parameters
    ----------
    grid_size
        Side length of the square, single-channel input image.
    num_contours
        Number of output channels, one per contour.

    Returns
    -------
    tf.keras.Model
        Model taking ``(grid_size, grid_size, 1)`` inputs. Its final layer
        is a 1x1 convolution with a sigmoid activation and
        ``num_contours`` filters.
    """
    inputs = tf.keras.layers.Input(
        (grid_size, grid_size, 1)
    )
    filter_counts = get_unet_filter_counts(grid_size)
    encoding_filter_counts, decoding_filter_counts = filter_counts

    # Encoding path: remember each level's features for the skip links
    x = inputs
    skips = []
    for number_of_filters in encoding_filter_counts:
        x, skip = encode(x, number_of_filters)
        skips.append(skip)

    # Decoding path: the deepest skip connection is consumed first
    for number_of_filters, skip in zip(
        decoding_filter_counts, reversed(skips)
    ):
        x = decode(x, skip, number_of_filters)

    output_layer = tf.keras.layers.Conv2D(
        filters=num_contours,
        kernel_size=1,
        activation="sigmoid",
        padding="same",
        kernel_initializer="he_normal",
    )
    outputs = output_layer(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Build the network for the configured grid size and number of contours,
# then print a layer-by-layer summary of it
model = unet(GRID_SIZE, NUM_CONTOURS)
model.summary()

In [None]:
# Configure training: Adam optimiser with its default settings, binary
# cross-entropy loss, and binary accuracy / recall / precision reported
# alongside the loss
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision(),
    ],
)

In [None]:
# Utilise the untrained model to create a prediction

In [None]:
# Use the previously defined plot_contours to show this prediction

In [None]:
def plot_with_prediction(image, masks, pred_masks):
    """Plot the given masks and the predicted masks side by side.

    A new two-column figure is created. The left axes shows ``image``
    with ``masks`` and the right axes shows the same ``image`` with
    ``pred_masks``, both drawn with ``plot_contours``.
    """
    _, (reference_ax, prediction_ax) = plt.subplots(
        ncols=2, figsize=(12, 6)
    )

    plot_contours(reference_ax, image, masks)
    plot_contours(prediction_ax, image, pred_masks)

In [None]:
# Use this new function, plot_with_prediction, to see a comparison

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots a sample prediction after every epoch.

    The notebook-level ``model``, ``image`` and ``masks`` variables are
    read directly, so ``on_epoch_end`` can also be called by hand outside
    of ``fit``.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the sample image, plot it, and print the epoch.

        ``epoch`` is printed plus one; ``logs`` is not used.
        """
        # The model is given a batch, so add a leading axis to the
        # image and strip it from the prediction again
        single_image_batch = image[None, ...]
        pred_masks = model.predict(single_image_batch)[0, ...]

        plot_with_prediction(image, masks, pred_masks)
        plt.show()

        message = (
            "\nSample Prediction after"
            " epoch {}\n".format(epoch + 1)
        )
        print(message)

In [None]:
# Describe classes
# Instantiate this callback class and
# call the on_epoch_end method to see what it does

## Training

In [None]:
# Swap to GPU runtime
# Run all cells

In [None]:
# Train for 50 epochs on the training dataset, evaluating on the validation
# dataset and plotting a sample prediction (DisplayCallback) each epoch.
# The returned object's `history` attribute is used for plotting below.
history = model.fit(
    datasets["training"],
    epochs=50,
    validation_data=datasets["validation"],
    callbacks=[DisplayCallback()],
)

In [None]:
# Compare the training results. Sometimes it won't converge.
# Can re-initialise the model and re-train in that case.

## BREAKOUT ROOMS ##
# -- Goal: Catch everyone up, and have everyone be able to train a UNet

In [None]:
# Training and validation loss per epoch, on a logarithmic y axis
plt.semilogy(history.history["loss"], label="Training loss")
plt.semilogy(
    history.history["val_loss"],
    label="Validation loss",
)
plt.legend()