<a href="https://colab.research.google.com/github/akhanf/biophys9709/blob/2025/2025_Lecture_10B_Unet_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unet Demo

- 2D images and segmentations
- `tf.data` datasets, created with `image_dataset_from_directory()`, to load the images and segmentations
- `'same'` padding so the output segmentation has the same size as the input image
- Explore how a custom loss can be defined and how this affects the result





In [None]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np

## Get Data

I created a brain dataset for this example.

Let's download it and see what it looks like.

In [None]:
# Download and unzip the dataset
import os

#download the zip file
if not os.path.exists('brain_2d_seg_data.zip'):
  !wget https://www.dropbox.com/s/3so8n63ast4dfcg/brain_2d_seg_data.zip
  !unzip brain_2d_seg_data.zip > /dev/null


## Take a look at what we have downloaded
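
A quick way to see what the zip contains is to count the files in each folder. This is a minimal standard-library sketch; `summarize_directory` is a hypothetical helper (not part of Keras), and the folder names it prints depend on the contents of the zip file.

```python
import os

def summarize_directory(root):
    """Map each folder under `root` (as a relative path) to the number of files it holds."""
    counts = {}
    for dirpath, _dirnames, filenames in os.walk(root):
        if filenames:  # skip folders that only contain other folders
            counts[os.path.relpath(dirpath, root)] = len(filenames)
    return dict(sorted(counts.items()))

# One line per folder; prints nothing if the data has not been unzipped yet
for folder, n_files in summarize_directory("brain_2d_seg_data").items():
    print(f"{folder}: {n_files} files")
```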

## Loading the dataset

In earlier versions of Keras, we would instantiate an `ImageDataGenerator` and use `flow_from_directory()` to read the data while performing augmentation.

Now, we will instead use the Keras `image_dataset_from_directory()` function to read the dataset, and then apply augmentation as a separate step.

We wrap this function, together with a normalization step, in a helper function, since we apply it to the images and masks of both the training and test datasets.

In [None]:
# Function to load datasets
def load_image_dataset(directory, image_size=(160, 160), batch_size=1, interpolation="bilinear"):
    dataset = tf.keras.utils.image_dataset_from_directory(
        directory,
        labels=None,  # Since it's segmentation, we don't need class labels
        image_size=image_size,
        batch_size=batch_size,
        color_mode="grayscale",  # Ensure grayscale loading
        interpolation=interpolation,  # use "nearest" for masks so resizing does not blend label values
        shuffle=False  # Keep file order so images and masks stay paired; we shuffle after zipping
    )
    # Normalize images (rescale pixel values to [0,1])
    return dataset.map(lambda x: x/255.0)

# Load images and masks for training
image_dataset_train = load_image_dataset("brain_2d_seg_data/training/images")
mask_dataset_train = load_image_dataset("brain_2d_seg_data/training/brain_masks", interpolation="nearest")

# Load images and masks for testing
image_dataset_test = load_image_dataset("brain_2d_seg_data/test/images")
mask_dataset_test = load_image_dataset("brain_2d_seg_data/test/brain_masks", interpolation="nearest")


# Zip images and masks together
train_dataset = tf.data.Dataset.zip((image_dataset_train, mask_dataset_train))
test_dataset = tf.data.Dataset.zip((image_dataset_test, mask_dataset_test))
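
Zipping the two datasets only pairs each image with its own mask if both folders list their files in the same order. A minimal sanity-check sketch, assuming (not confirmed here) that each mask has the same filename as its image; `image_basenames` and `check_pairing` are hypothetical helpers, and the extension list mirrors the formats `image_dataset_from_directory()` accepts.

```python
import os

# File types that image_dataset_from_directory() can load
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif")

def image_basenames(directory):
    """Sorted filenames of all images under `directory`, searched recursively."""
    names = []
    for _dirpath, _dirnames, filenames in os.walk(directory):
        names.extend(f for f in filenames if f.lower().endswith(IMAGE_EXTENSIONS))
    return sorted(names)

def check_pairing(image_dir, mask_dir):
    """Filenames that appear in only one of the two folders (empty list = all paired)."""
    images = set(image_basenames(image_dir))
    masks = set(image_basenames(mask_dir))
    return sorted(images ^ masks)

print(check_pairing("brain_2d_seg_data/training/images",
                    "brain_2d_seg_data/training/brain_masks"))
```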


In [None]:
# define augmentation function, which takes both image and mask:
from tensorflow.keras.layers import RandomFlip, RandomRotation, RandomZoom, RandomTranslation, RandomBrightness


augmentation_layer = tf.keras.Sequential([
    RandomFlip("horizontal"),
    RandomRotation(0.1,fill_mode='constant',fill_value=0),
])


def augment_with_keras(image, mask):
    """Apply the same Keras augmentation to both image and mask."""

    # Stack image & mask to apply the same transformation
    combined = tf.concat([image, mask], axis=-1)

    # Apply augmentation (training=True makes explicit that the random transforms should be applied)
    augmented = augmentation_layer(combined, training=True)

    # Split image and mask back
    augmented_image = augmented[..., 0:1]  # First channel is the image
    augmented_mask = augmented[..., 1:2]  # Second channel is the mask

    return augmented_image, augmented_mask
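
Concatenating the image and mask along the channel axis means the augmentation layer transforms both channels of each sample together, so the mask stays aligned with the image. The NumPy sketch below mimics this with a fixed horizontal flip standing in for `RandomFlip`; the toy 4x4 arrays are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((1, 4, 4, 1))                          # (batch, height, width, channels)
mask = (rng.random((1, 4, 4, 1)) > 0.5).astype(float)    # toy binary mask

combined = np.concatenate([image, mask], axis=-1)         # shape (1, 4, 4, 2)
flipped = combined[:, :, ::-1, :]                         # one horizontal flip, applied to both channels

# Split back into image and mask, as augment_with_keras does
flipped_image = flipped[..., 0:1]
flipped_mask = flipped[..., 1:2]
```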



In [None]:
for img_batch, mask_batch in train_dataset.take(5):  # original (un-augmented) image/mask pairs

    aug_img, aug_mask = augment_with_keras(img_batch, mask_batch)  # apply augmentation manually

    # Visualize
    plt.figure()
    plt.imshow(img_batch[0, :, :, 0], cmap="gray")
    plt.imshow(mask_batch[0, :, :, 0], alpha=0.5)
    plt.colorbar()

    plt.figure()
    plt.imshow(aug_img[0, :, :, 0], cmap="gray")
    plt.imshow(aug_mask[0, :, :, 0], alpha=0.5)
    plt.colorbar()


In [None]:
# Apply augmentation, then shuffle, then repeat indefinitely
train_dataset = train_dataset.map(augment_with_keras)
train_dataset = train_dataset.shuffle(buffer_size=100)
train_dataset = train_dataset.repeat()
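
`shuffle(buffer_size=100)` does not shuffle the whole dataset at once: it keeps a buffer of 100 elements (here, batches of one image) and draws randomly from it, so if the training set has more than 100 images the order is only shuffled locally. The pure-Python sketch below is a simplified model of that behaviour, not the actual `tf.data` implementation; `buffered_shuffle` is a hypothetical name.

```python
import random

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in the order a fixed-size shuffle buffer would produce."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Buffer is full: emit a random element to make room for the next one
            yield buffer.pop(rng.randrange(len(buffer)))
    # Input exhausted: drain the remaining elements in random order
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))

# With a buffer of 3, an element can only move a limited distance from its original position
print(list(buffered_shuffle(range(10), buffer_size=3, seed=0)))
```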

### Use the Dice metric to evaluate performance

In [None]:
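def dice_metric(y_true, y_pred):
    """Soft Dice coefficient over all pixels in the batch (1 = perfect overlap)."""
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    # the +1 smoothing terms avoid 0/0 when both masks are empty
    score = (2. * intersection + 1.) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + 1.)

    return score

def dice_loss(y_true, y_pred):
    # minimizing 1 - Dice maximizes the overlap
    return (1 - dice_metric(y_true, y_pred))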
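
Both functions use the smoothed Dice coefficient computed over all pixels in the batch, where $A$ is the true mask and $B$ the prediction (with $|A \cap B|$ taken as the sum of their pixelwise product):

$$\text{Dice}(A, B) = \frac{2\,|A \cap B| + 1}{|A| + |B| + 1}$$

A NumPy mirror of `dice_metric` makes the edge cases easy to check. `dice_numpy` is a hypothetical helper written for illustration, not part of the notebook's training code.

```python
import numpy as np

def dice_numpy(y_true, y_pred, smooth=1.0):
    """NumPy version of dice_metric: (2*intersection + smooth) / (sum_true + sum_pred + smooth)."""
    y_true = np.ravel(y_true).astype(float)
    y_pred = np.ravel(y_pred).astype(float)
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)

truth = np.array([1, 1, 0, 0])
print(dice_numpy(truth, np.array([1, 1, 0, 0])))   # perfect overlap -> 1.0
print(dice_numpy(truth, np.array([0, 0, 1, 1])))   # no overlap -> 0.2 (smoothing keeps it above 0)
print(dice_numpy(np.zeros(4), np.zeros(4)))        # both empty -> 1.0 instead of 0/0
```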



## Create our model

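Use:
- Max-pooling & up-convolutions (upsampling followed by a 2x2 convolution)
- 16 filters in the first conv layer, doubling at each level down to 256
- 2 convolutions in each stage
- 4 skip connections
- `'same'` padding so each convolution keeps the spatial size, giving an output the same size as the input


Each layer's output is assigned to `x`, which makes it easier to copy, paste, and rearrange blocks.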
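
Because each of the 4 max-pooling steps halves the image and each upsampling step doubles it, the input size must be divisible by $2^4 = 16$ for the skip connections to line up. A minimal sketch of that arithmetic; `unet_level_sizes` is a hypothetical helper.

```python
def unet_level_sizes(input_size, n_pools=4, pool=2):
    """Spatial size at each resolution level of a U-Net with `n_pools` pooling steps."""
    sizes = [input_size]
    for _ in range(n_pools):
        if sizes[-1] % pool:
            # MaxPooling2D would round down, so upsampling could not restore the skip shape
            raise ValueError(f"size {sizes[-1]} is not divisible by {pool}")
        sizes.append(sizes[-1] // pool)
    return sizes

print(unet_level_sizes(160))   # 160 -> 80 -> 40 -> 20 -> 10
```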



In [None]:
img_shape = (160,160,1)
input_layer = keras.layers.Input(img_shape)

x = keras.layers.Conv2D(16,(3,3), padding='same',activation='relu')(input_layer)
x = keras.layers.Conv2D(16,(3,3), padding='same',activation='relu')(x)
out_layer1 = x
x = keras.layers.MaxPooling2D((2,2))(x)

x = keras.layers.Conv2D(32,(3,3), padding='same',activation='relu')(x)
x = keras.layers.Conv2D(32,(3,3), padding='same',activation='relu')(x)
out_layer2 = x
x = keras.layers.MaxPooling2D((2,2))(x)

x = keras.layers.Conv2D(64,(3,3), padding='same',activation='relu')(x)
x = keras.layers.Conv2D(64,(3,3), padding='same',activation='relu')(x)
out_layer3 = x
x = keras.layers.MaxPooling2D((2,2))(x)


x = keras.layers.Conv2D(128,(3,3), padding='same',activation='relu')(x)
x = keras.layers.Conv2D(128,(3,3), padding='same',activation='relu')(x)
out_layer4 = x
x = keras.layers.MaxPooling2D((2,2))(x)

x = keras.layers.Conv2D(256,(3,3), padding='same',activation='relu')(x)
x = keras.layers.Conv2D(256,(3,3), padding='same',activation='relu')(x)

x = keras.layers.UpSampling2D((2,2))(x)
x = keras.layers.Conv2D(128,(2,2), padding='same',activation='relu')(x)

x = keras.layers.Concatenate(axis=3)([out_layer4,x])
x = keras.layers.Conv2D(128,(3,3), padding='same',activation='relu')(x)
x = keras.layers.Conv2D(128,(3,3), padding='same',activation='relu')(x)


x = keras.layers.UpSampling2D((2,2))(x)
x = keras.layers.Conv2D(64,(2,2), padding='same',activation='relu')(x)

x = keras.layers.Concatenate(axis=3)([out_layer3,x])
x = keras.layers.Conv2D(64,(3,3), padding='same',activation='relu')(x)
x = keras.layers.Conv2D(64,(3,3), padding='same',activation='relu')(x)


x = keras.layers.UpSampling2D((2,2))(x)
x = keras.layers.Conv2D(32,(2,2), padding='same',activation='relu')(x)


x = keras.layers.Concatenate(axis=3)([out_layer2,x])
x = keras.layers.Conv2D(32,(3,3), padding='same',activation='relu')(x)
x = keras.layers.Conv2D(32,(3,3), padding='same',activation='relu')(x)

x = keras.layers.UpSampling2D((2,2))(x)
x = keras.layers.Conv2D(16,(2,2), padding='same',activation='relu')(x)

x = keras.layers.Concatenate(axis=3)([out_layer1,x])
x = keras.layers.Conv2D(16,(3,3), padding='same',activation='relu')(x)
x = keras.layers.Conv2D(16,(3,3), padding='same',activation='relu')(x)

#1x1 conv with sigmoid to get binary classification at each pixel
x = keras.layers.Conv2D(1,(1,1), padding='same',activation='sigmoid')(x)

model = keras.Model(input_layer, x)
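
Only the `Conv2D` layers in this network have trainable weights (pooling, upsampling and concatenation layers have none), so the parameter count can be checked by hand. A sketch, assuming the filter counts used above; `conv_params` and `unet_param_count` are hypothetical helpers, and the number printed should agree with the "Total params" line from `model.summary()` in the next cell.

```python
def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * c_in * c_out + c_out

def unet_param_count(c_in=1, filters=(16, 32, 64, 128, 256)):
    total = 0
    # Encoder and bottleneck: two 3x3 convs per level
    prev = c_in
    for f in filters:
        total += conv_params(3, prev, f) + conv_params(3, f, f)
        prev = f
    # Decoder: 2x2 up-conv after UpSampling2D, concatenation with the skip, two 3x3 convs
    for f in reversed(filters[:-1]):
        total += conv_params(2, prev, f)    # up-conv halves the number of channels
        total += conv_params(3, 2 * f, f)   # input is [skip (f channels), up-conv (f channels)]
        total += conv_params(3, f, f)
        prev = f
    # Final 1x1 conv with a single sigmoid output channel
    total += conv_params(1, prev, 1)
    return total

print(unet_param_count())
```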

### Compile and visualize it

In [None]:

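# Compile the model
optimizer = keras.optimizers.Adam()

# Keras matches a *list* of losses to a list of model outputs. This model has a
# single output, so combine binary cross-entropy and Dice loss into one function.
def bce_dice_loss(y_true, y_pred):
    bce = tf.reduce_mean(keras.losses.binary_crossentropy(y_true, y_pred))
    return bce + dice_loss(y_true, y_pred)

loss = bce_dice_loss
metrics = ['binary_accuracy', dice_metric]

model.compile(loss=loss,
              optimizer=optimizer,
              metrics=metrics)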

# What does the finished model look like?
model.summary()
keras.utils.plot_model(model, show_shapes=True, rankdir='TD')

## Fit the model

The `fit()` function normally takes `x` and `y`; here we want to pass the images and masks, respectively.

Since we are passing a `tf.data.Dataset` instead, each element it yields must be an `(image, mask)` tuple.

We created this structure earlier by combining the image and mask datasets with `tf.data.Dataset.zip()`.




In [None]:
for x, y in train_dataset.take(1):
    print("Image shape:", x.shape)
    print("Mask shape:", y.shape)


In [None]:
# train_dataset already yields (image, mask) pairs and repeats indefinitely,
# so steps_per_epoch sets how many batches make up one epoch
history = model.fit(train_dataset,
                    steps_per_epoch=100,
                    epochs=10)



Plot the loss and metrics on the training data.

In [None]:
#plot loss and metrics
import pandas as pd
import seaborn as sns
df = pd.DataFrame(history.history)
sns.lineplot(data=df)


Evaluate the metrics on the test data.

In [None]:
# model.evaluate returns the average loss and metrics over the whole (finite) test set
test_metrics = model.evaluate(test_dataset, return_dict=True)
print(test_metrics)

Let's take a look at some results for the test dataset

In [None]:
# plot some examples from the test set: ground truth (left) vs. prediction (right)

for image, mask in test_dataset.take(6):

  predicted = model.predict(image, verbose=0)

  plt.figure()

  plt.subplot(1,2,1)
  plt.imshow(np.squeeze(image),cmap='gray')
  plt.imshow(np.squeeze(mask),alpha=0.5)
  plt.title('Ground truth')

  plt.subplot(1,2,2)
  plt.imshow(np.squeeze(image),cmap='gray')
  plt.imshow(np.squeeze(predicted),alpha=0.5)
  plt.title('Prediction')


## Now what?

- Try segmenting the ventricles with this configuration

- Try using a custom loss function



## Beyond this example

### More than just binary labels?

We took a shortcut here and used a single channel output and sigmoid to get our label. More generally if you have multiple labels, you will want to use a *one-hot* encoding analogous to what we did for multi-class classification, and use a soft-max activation.

For example, to set up our example data (brain and ventricle segmentation) this way, we would need the following 3 channels:
- background: 1 where neither brain nor ventricles are, 0 elsewhere
- brain: 1 where brain is, 0 elsewhere (we have this already)
- ventricles: 1 where ventricles are, 0 elsewhere (we have this already)

So in this case you would need to create the background channel, e.g. via logical operations on the other channels. In a valid one-hot encoding each pixel is 1 in exactly one channel, so if the brain mask also covers the ventricles, the ventricle pixels need to be removed from the brain channel.

If you are starting with a single label image that uses integer values for the different labels (e.g. 0 for background, 1 for brain, 2 for ventricles), you can use Keras' built-in `keras.utils.to_categorical()` to convert it to this one-hot format.
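
A NumPy sketch of building the three channels from two binary masks, assuming the brain mask may also cover the ventricles (so ventricle pixels are removed from the brain channel). `one_hot_from_masks` and the toy arrays are hypothetical; the last lines show the integer-label route, using `np.eye` so the sketch runs without Keras.

```python
import numpy as np

def one_hot_from_masks(brain, ventricles):
    """Stack two binary masks into (H, W, 3) one-hot channels: background, brain, ventricles."""
    brain = np.asarray(brain).astype(bool)
    ventricles = np.asarray(ventricles).astype(bool)
    brain_only = brain & ~ventricles        # make the channels mutually exclusive
    background = ~(brain | ventricles)      # neither brain nor ventricles
    return np.stack([background, brain_only, ventricles], axis=-1).astype(np.float32)

# Toy 2x3 masks (made up for illustration)
brain = np.array([[0, 1, 1],
                  [0, 1, 1]])
ventricles = np.array([[0, 0, 1],
                       [0, 0, 0]])
one_hot = one_hot_from_masks(brain, ventricles)
print(one_hot.sum(axis=-1))   # every pixel belongs to exactly one channel

# Integer label image (0 = background, 1 = brain, 2 = ventricles) and back again;
# indexing np.eye gives the same layout as keras.utils.to_categorical(labels, num_classes=3)
labels = np.argmax(one_hot, axis=-1)
one_hot_again = np.eye(3, dtype=np.float32)[labels]
```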

Note that if you have multiple labels, you would also need to account for this in any custom loss functions you create.
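
For a custom loss on one-hot outputs, one common option (not part of the original notebook) is to compute the Dice score separately for each channel and average the results. A NumPy sketch; `mean_dice_multiclass` is a hypothetical helper, and a Keras loss would use the same formula with `tf.reduce_sum` and return `1 - mean_score`.

```python
import numpy as np

def mean_dice_multiclass(y_true, y_pred, smooth=1.0):
    """Soft Dice per class (last axis) of one-hot arrays, plus the mean over classes."""
    axes = tuple(range(y_true.ndim - 1))          # sum over every axis except the class axis
    intersection = np.sum(y_true * y_pred, axis=axes)
    totals = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    per_class = (2.0 * intersection + smooth) / (totals + smooth)
    return per_class.mean(), per_class

# Toy example: true labels vs. a prediction that calls every pixel background
y_true = np.eye(3)[np.array([[0, 1, 2], [0, 1, 1]])]
y_pred = np.eye(3)[np.zeros((2, 3), dtype=int)]
mean_score, per_class = mean_dice_multiclass(y_true, y_pred)
print(per_class)   # one score per channel: background, brain, ventricles
```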

### Validation split?
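We didn't use a validation split here, for the sake of simplicity. You can add one either by putting your validation data in another directory and creating another Dataset, **or** by passing `validation_split` (e.g. `0.2`) together with `subset='training'` or `subset='validation'` to `image_dataset_from_directory()`. If you split this way, use the same `validation_split` (and the same `seed`, if shuffling) for the image and mask datasets so the pairs stay aligned.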
