# U-Net Workshop

This is an example U-Net implementation that uses minified data derived from the Creative Commons licensed HNSCC dataset available at https://wiki.cancerimagingarchive.net/display/Public/HNSCC.

In [None]:
# Copyright 2021 Radiotherapy AI Pty Ltd

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#    http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import pathlib
import random
import shutil
import urllib.request

import imageio
import matplotlib.pyplot as plt
import tensorflow as tf

In [None]:
# Number of samples grouped into each batch of the batched datasets.
BATCH_SIZE = 512
# Side length of the square, single-channel model input.
GRID_SIZE = 64

# One (line colour, legend label) pair per mask channel, in channel order.
COLOURS_AND_LABELS = [
    ("#ff7f0e", "left parotid"),
    ("#2ca02c", "right parotid"),
    ("#d62728", "external"),
]
# Number of structures; also used as the model's output channel count.
NUM_CONTOURS = len(COLOURS_AND_LABELS)

## Download the data

In [None]:
# Zip archive of the workshop data, published as a GitHub release asset.
zip_url = "https://github.com/RadiotherapyAI/unet-workshop/releases/download/mini-parotid/mini-parotid.zip"
# Local file name the archive is downloaded to.
zip_filepath = "data.zip"

# Directory the archive is unpacked into.
data_directory = pathlib.Path("data")

# Download and unpack only when the data directory does not exist yet.
if not data_directory.exists():
    urllib.request.urlretrieve(zip_url, zip_filepath)
    shutil.unpack_archive(zip_filepath, data_directory)

In [None]:
# Names of the sub-directories found directly inside the data directory.
dataset_types = [path.name for path in data_directory.glob("*") if path.is_dir()]
# Show the discovered names as the cell output.
dataset_types

## Build the TensorFlow pipeline

In [None]:
def get_path_pairs(dataset_type):
    """Return sorted ``(image_path, masks_path)`` string pairs for a dataset.

    Images are found by globbing
    ``<data_directory>/<dataset_type>/*/*.image.png``. Each image is paired
    with the ``.masks.png`` file in the same directory whose name matches the
    image's name up to the ``.image.png`` suffix. The mask files are not
    checked for existence.

    Args:
        dataset_type: Name of a sub-directory of ``data_directory``.

    Returns:
        A sorted list of ``(image_path, masks_path)`` tuples of ``str``.
    """
    image_suffix = ".image.png"
    masks_suffix = ".masks.png"

    path_pairs = []
    for image_path in (data_directory / dataset_type).glob(f"*/*{image_suffix}"):
        # Strip only the trailing suffix, so a stem that itself contains dots
        # is kept whole instead of being cut at its first dot.
        stem = image_path.name[: -len(image_suffix)]
        masks_path = image_path.with_name(f"{stem}{masks_suffix}")
        path_pairs.append((str(image_path), str(masks_path)))

    return sorted(path_pairs)

In [None]:
@tf.function
def load(path_pair):
    """Read one image/masks file pair and divide the decoded values by 255.

    Args:
        path_pair: Indexable pair; element 0 is the image file path and
            element 1 is the masks file path.

    Returns:
        ``(image, masks)``: the image decoded with one channel and the masks
        decoded with three channels, both decoded as ``uint8`` and then
        divided by 255.
    """
    image_file, masks_file = path_pair[0], path_pair[1]

    decoded_image = tf.io.decode_image(
        tf.io.read_file(image_file), channels=1, dtype=tf.uint8
    )
    decoded_masks = tf.io.decode_image(
        tf.io.read_file(masks_file), channels=3, dtype=tf.uint8
    )

    return decoded_image / 255, decoded_masks / 255

In [None]:
def create_datasets(dataset_type):
    """Build the per-sample and batched ``tf.data`` pipelines for a dataset.

    Args:
        dataset_type: Name of the data sub-directory to load; passed on to
            ``get_path_pairs``.

    Returns:
        ``(dataset, batched_dataset)``: the shuffled dataset of samples
        produced by ``load``, and the same dataset grouped into batches of
        ``BATCH_SIZE`` with prefetching.

    Raises:
        ValueError: If no image/mask path pairs exist for ``dataset_type``.
    """
    path_pairs = get_path_pairs(dataset_type)
    if not path_pairs:
        # Fail early with a clear message; an empty list would otherwise only
        # surface later as an obscure error inside the tf.data pipeline.
        raise ValueError(
            f"No image/mask pairs found for dataset type {dataset_type!r}"
        )

    dataset = tf.data.Dataset.from_tensor_slices(path_pairs)
    # Shuffle the path strings (before any file is decoded) with a buffer
    # that spans every pair, reshuffling on each pass over the dataset.
    dataset = dataset.shuffle(len(path_pairs), reshuffle_each_iteration=True)
    # Decode files in parallel; tf.data keeps the element order deterministic
    # by default, so the resulting samples are unchanged.
    dataset = dataset.map(load, num_parallel_calls=tf.data.AUTOTUNE)

    batched_dataset = dataset.batch(BATCH_SIZE)
    batched_dataset = batched_dataset.prefetch(tf.data.AUTOTUNE)

    return dataset, batched_dataset

In [None]:
# Build the per-sample and batched pipelines for both dataset types.
training_dataset, training_batched_dataset = create_datasets("training")
validation_dataset, validation_batched_dataset = create_datasets("validation")

In [None]:
# Pull a single (image, masks) sample from the validation data for the example
# plots and predictions below. The built-in next() is used instead of the
# Python-2-style ``.next()`` method, which not every iterator provides.
image, masks = next(iter(validation_dataset.take(1)))

In [None]:
# Show channel 0 of the sample image.
plt.imshow(image[:, :, 0])

In [None]:
# Show the sample's masks array (decoded with three channels) as an image.
plt.imshow(masks)

In [None]:
def plot_contours(ax, image, masks):
    """Draw the image in greyscale with one labelled contour per structure.

    Args:
        ax: Matplotlib axes to draw on.
        image: Array indexed as ``image[:, :, 0]`` for display.
        masks: Array whose last axis holds one channel per entry of
            ``COLOURS_AND_LABELS``, in the same order.
    """
    ax.imshow(image[:, :, 0], cmap="gray")

    legend_handles = []
    for i, (colour, label) in enumerate(COLOURS_AND_LABELS):
        # Outline each structure at the 0.5 level of its mask channel.
        ax.contour(masks[..., i], colors=[colour], levels=[0.5])
        # Build the legend from proxy lines instead of reaching into
        # ``ContourSet.collections``, which is deprecated since Matplotlib 3.8.
        legend_handles.append(plt.Line2D([], [], color=colour, label=label))

    ax.axis("equal")
    ax.legend(handles=legend_handles)

In [None]:
# Plot the sample image with its reference contours on a single square figure.
fig, ax = plt.subplots(figsize=(6, 6))
plot_contours(ax, image, masks)

In [None]:
def activation(x):
    """Pass ``x`` through a new ReLU activation layer and return the output."""
    relu_layer = tf.keras.layers.Activation("relu")
    return relu_layer(x)


def convolution(x, number_of_filters, kernel_size=3):
    """Apply a new same-padded, He-normal-initialised 2D convolution to ``x``.

    Args:
        x: Input tensor.
        number_of_filters: Number of filters of the convolution.
        kernel_size: Kernel size of the convolution.

    Returns:
        The output of the convolution layer.
    """
    conv_layer = tf.keras.layers.Conv2D(
        filters=number_of_filters,
        kernel_size=kernel_size,
        padding="same",
        kernel_initializer="he_normal",
    )
    return conv_layer(x)


def conv_transpose(x, number_of_filters, kernel_size=3):
    """Apply a new stride-2, same-padded transposed 2D convolution to ``x``.

    Args:
        x: Input tensor.
        number_of_filters: Number of filters of the transposed convolution.
        kernel_size: Kernel size of the transposed convolution.

    Returns:
        The output of the transposed convolution layer.
    """
    upsampling_layer = tf.keras.layers.Conv2DTranspose(
        filters=number_of_filters,
        kernel_size=kernel_size,
        strides=2,
        padding="same",
        kernel_initializer="he_normal",
    )
    return upsampling_layer(x)

In [None]:
def encode(x, number_of_filters, number_of_convolutions=2):
    """Run one encoder stage: repeated convolutions followed by max pooling.

    Args:
        x: Input tensor.
        number_of_filters: Number of filters for every convolution in the stage.
        number_of_convolutions: How many convolution + activation pairs to apply.

    Returns:
        ``(pooled, skip)``: the max-pooled and activated output, and the
        tensor taken just before pooling for use as a skip connection.
    """
    features = x
    for _ in range(number_of_convolutions):
        features = activation(convolution(features, number_of_filters))

    # Keep the pre-pooling tensor for the matching decoder stage.
    skip = features

    pooled = activation(tf.keras.layers.MaxPool2D()(features))

    return pooled, skip


def decode(x, skip, number_of_filters, number_of_convolutions=2):
    """Run one decoder stage: upsample, merge the skip tensor, then convolve.

    Args:
        x: Input tensor from the previous stage.
        skip: Skip-connection tensor to concatenate with the upsampled input.
        number_of_filters: Number of filters for the transposed convolution
            and for every following convolution.
        number_of_convolutions: How many convolution + activation pairs to
            apply after the concatenation.

    Returns:
        The output tensor of the stage.
    """
    upsampled = activation(conv_transpose(x, number_of_filters))

    # Join along axis 3, with the skip tensor first.
    merged = tf.keras.layers.concatenate([skip, upsampled], axis=3)

    for _ in range(number_of_convolutions):
        merged = activation(convolution(merged, number_of_filters))

    return merged

In [None]:
# Assemble the U-Net as a Keras functional model.
# The input is a single-channel GRID_SIZE x GRID_SIZE image.
inputs = tf.keras.layers.Input((GRID_SIZE, GRID_SIZE, 1))

x = inputs
skips = []

# Encoder: three stages with increasing filter counts; each stage's
# pre-pooling output is kept for the decoder.
for number_of_filters in [32, 64, 128]:
    x, skip = encode(x, number_of_filters)
    skips.append(skip)

# The decoder consumes the skips in the opposite order to their creation.
skips.reverse()

# Decoder: one stage per stored skip tensor.
for number_of_filters, skip in zip([256, 128, 64], skips):
    x = decode(x, skip, number_of_filters)

# Final 1x1 convolution with a sigmoid activation, giving one output channel
# per contour.
x = tf.keras.layers.Conv2D(
    NUM_CONTOURS,
    1,
    activation="sigmoid",
    padding="same",
    kernel_initializer="he_normal",
)(x)

model = tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Print the layer-by-layer summary of the assembled model.
model.summary()

In [None]:
# Configure training: Adam optimiser with its default settings, binary
# cross-entropy loss, and binary accuracy, recall and precision as metrics.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision(),
    ],
)

In [None]:
# Display the shape of the sample image as the cell output.
image.shape

In [None]:
# Predict on the single sample: add a leading axis for the call, then take
# the first entry of the result.
pred_masks = model.predict(image[None, ...])[0, ...]

In [None]:
def plot_with_prediction(image, masks, pred_masks):
    """Plot reference and predicted contours side by side on one figure.

    Args:
        image: Image drawn in both panels.
        masks: Masks contoured in the left panel.
        pred_masks: Masks contoured in the right panel.
    """
    _, (reference_ax, prediction_ax) = plt.subplots(ncols=2, figsize=(12, 6))
    plot_contours(reference_ax, image, masks)
    plot_contours(prediction_ax, image, pred_masks)


# Compare the sample's reference masks with the current prediction.
plot_with_prediction(image, masks, pred_masks)

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots a sample prediction at the end of each epoch.

    The sample is read from the notebook-level ``image`` and ``masks``
    variables.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the sample, plot it, and print the epoch number.

        Args:
            epoch: Zero-based epoch index; printed as ``epoch + 1``.
            logs: Unused.
        """
        # Predict with the model this callback is attached to, rather than
        # with whatever the notebook-level ``model`` name currently refers to.
        pred_masks = self.model.predict(image[None, ...])[0, ...]
        plot_with_prediction(image, masks, pred_masks)

        plt.show()
        print(f"\nSample Prediction after epoch {epoch + 1}\n")

In [None]:
# Train for 50 epochs on the batched training data, with the batched
# validation data passed as validation_data and a sample prediction plotted
# through DisplayCallback.
history = model.fit(
    training_batched_dataset,
    epochs=50,
    validation_data=validation_batched_dataset,
    callbacks=[DisplayCallback()],
)

In [None]:
# Plot the recorded training and validation loss on a logarithmic y-axis.
plt.semilogy(history.history["loss"], label="Training loss")
plt.semilogy(history.history["val_loss"], label="Validation loss")
plt.legend()