<img src="https://futurejobs.my/wp-content/uploads/2021/05/d-min-1024x297.png" width="300"> </img>

> **Copyright &copy; 2021 Skymind Education Group Sdn. Bhd.**<br>
 <br>
This program and the accompanying materials are made available under the
terms of the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0). \
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations
under the License. <br>
<br>**SPDX-License-Identifier: Apache-2.0** 

# Image Segmentation 


## Introduction

Object segmentation is the process of finding the boundaries of target objects in images, and it has many applications. For example, by outlining anatomical objects in medical images, clinical experts can learn useful information about patients' conditions. The goal of automatic single-object segmentation is to predict a binary mask for a given image, in which the object of interest is white and the background is black.

<img src="https://rumc-gcorg-p-public.s3.amazonaws.com/b/265/bannerV3_V5OH10E.x15.jpeg"></img>

This hands-on will guide you through building a pipeline, from scratch, to automatically segment the fetal head in ultrasound images.

_Authored by: [Scotrraaj Gopal](http://github.com/scotgopal)_

## Objectives

In this hands-on, we will:

1. Download the fetal head dataset.
2. Create a custom Dataset object.
3. Define the deep learning model.
4. Define the loss function and optimizer.
5. Train the model.
6. Test the model.

## Data Acquisition

In [None]:
from urllib import request
import zipfile
from pathlib import Path
from tqdm import tqdm

In [None]:
class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)

download_links = ["https://zenodo.org/record/1322001/files/training_set.zip?download=1","https://zenodo.org/record/1322001/files/test_set.zip?download=1"]
destination_files = ["training_set.zip", "test_set.zip"]
DATASET_BASE_PATH = Path("../datasets/FetalHeadDataset").resolve()

DATASET_BASE_PATH.mkdir(parents=True, exist_ok=True)  # also creates "../datasets" if it is missing

for download_link, destination_file in zip(download_links, destination_files):
    destination_file = Path.joinpath(DATASET_BASE_PATH, destination_file)
    if not destination_file.exists():
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=download_link.split('/')[-1]) as t:
            request.urlretrieve(download_link, destination_file, reporthook=t.update_to)
        # the context manager closes the archive even if extraction fails
        with zipfile.ZipFile(destination_file) as zip_ref:
            zip_ref.extractall(DATASET_BASE_PATH)
    else:
        print(f"{destination_file} already exists, skipping download!")

## Data Exploration

According to the challenge documentation [here](https://hc18.grand-challenge.org/), this dataset was released as part of a challenge on automated fetal head segmentation. Each image is 800x540 pixels. The training set has 999 images and the test set has 335. For every training image, the training set also includes an image of the **manual annotation** of the head circumference.

Let's verify this by opening the downloaded dataset directory and viewing the data itself. You will see many `.png` files in both the `training_set` and `test_set` folders. The file names of the targets end with `_Annotation` (for example, `xxx.png` is annotated by `xxx_Annotation.png`). Let's verify the image counts programmatically.

In [None]:
train_val_dir = Path.joinpath(DATASET_BASE_PATH, "training_set")
feature_image_list = []
target_image_list = []
for path in list(train_val_dir.glob("*")):
    if ".png" in path.name:
        if "Annotation" in path.name:
            target_image_list.append(path)
        else:
            feature_image_list.append(path)

print("Total training images:", len(feature_image_list))
print("Total target images:", len(target_image_list))
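
Note that `glob` returns files in no guaranteed order. Two lists collected this way are therefore not guaranteed to line up index by index.

Below is a sketch of a more robust way to pair each image with its annotation. It assumes every annotation sits next to its image with `_Annotation` appended to the file stem, which is the same convention the visualization code below relies on. `annotation_path_for` and `pair_images_with_annotations` are hypothetical helpers, and `000_HC.png` is only an example name.

```python
from pathlib import Path

def annotation_path_for(image_path: Path) -> Path:
    """Return the annotation path next to an image: "xxx.png" -> "xxx_Annotation.png"."""
    return image_path.with_name(image_path.stem + "_Annotation" + image_path.suffix)

def pair_images_with_annotations(image_dir: Path):
    """List (image, annotation) path pairs in a deterministic, name-sorted order."""
    image_paths = sorted(
        p for p in image_dir.glob("*.png") if not p.stem.endswith("_Annotation")
    )
    return [(p, annotation_path_for(p)) for p in image_paths]

print(annotation_path_for(Path("000_HC.png")).name)  # 000_HC_Annotation.png
```

Deriving the annotation path from the image path means a pair can never be mismatched, whatever order the filesystem lists the files in.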

Let's visualize a small sample of these training images along with their annotations.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from PIL import Image
from skimage.segmentation import mark_boundaries

np.random.seed(0)
plt.rcParams["figure.figsize"] = [17,10]

total_images_to_display = 3
random_image_paths = np.random.choice(feature_image_list, total_images_to_display)
corresponding_target_paths = [str(image_path).replace(".png", "_Annotation.png") for image_path in random_image_paths]

def get_masked_image(feature_image, target_image):
    """Function to plot the boundaries of the binary mask onto the input image"""
    # mark_boundaries returns a 3-channel image, ranging from 0 - 1
    normalized_masked_image = mark_boundaries(np.array(feature_image), np.array(target_image), color=(0,1,0), mode="thick")
    masked_image = (normalized_masked_image*255).astype(np.uint8)
    return masked_image

for feature_image, target_image in zip(random_image_paths, corresponding_target_paths):
    feature_image_pil = Image.open(feature_image)
    target_image_pil = Image.open(target_image)
    masked_image = get_masked_image(feature_image_pil, target_image_pil)
    plt.figure()
    plt.subplot(1,3,1)
    plt.imshow(feature_image_pil, cmap="gray")
    plt.subplot(1,3,2)
    plt.imshow(target_image_pil, cmap="gray")
    plt.subplot(1,3,3)
    plt.imshow(masked_image)

## Data Transformation

There are many packages we can choose from to transform images. `torchvision` provides its built-in `transforms` module, and there are third-party packages such as [`Albumentations`](https://albumentations.ai/) and [`imgaug`](https://imgaug.readthedocs.io/en/latest/).

In this hands-on, we will show you how to transform the dataset using `Albumentations`.

In [None]:
from albumentations import Compose, Resize, HorizontalFlip, VerticalFlip

height, width = 128, 192
transform_train = Compose([Resize(height=height, width=width), HorizontalFlip(p=0.5), VerticalFlip(p=0.5)])
transform_val = Compose([Resize(height=height, width=width)])
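
Passing the feature image and its mask to the transform in the same call (`image=...`, `mask=...`) is what keeps them aligned: a random flip applied to one is applied to the other.

The NumPy sketch below illustrates that idea with a hypothetical `random_flip_pair` helper. It is a simplified illustration, not Albumentations' implementation.

```python
import numpy as np

def random_flip_pair(image, mask, rng, p_horizontal=0.5, p_vertical=0.5):
    """Randomly flip an image and its mask together so they stay pixel-aligned.

    Illustrative only: each random decision is drawn once and shared by both arrays.
    """
    if rng.random() < p_horizontal:
        image, mask = image[:, ::-1], mask[:, ::-1]  # left-right flip (width axis)
    if rng.random() < p_vertical:
        image, mask = image[::-1, :], mask[::-1, :]  # up-down flip (height axis)
    return image, mask

# Usage: a toy "image" and a mask derived from it (1 where the pixel value is even)
rng = np.random.default_rng(0)
image = np.arange(12, dtype=np.uint8).reshape(3, 4)
mask = (image % 2 == 0).astype(np.uint8)
flipped_image, flipped_mask = random_flip_pair(image, mask, rng)
# Both arrays received the same flips, so the mask still describes the image
assert np.array_equal(flipped_mask, (flipped_image % 2 == 0).astype(np.uint8))
```

Flipping the image and the mask with independent random draws would silently misalign targets in about half of the samples.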

Let's see what the training transform does to a feature image, its annotation image and the corresponding masked image.

_Note: Since the flips are applied at random (each with probability `p`), every run of the following cell may return a different output._

In [None]:
for feature_image, target_image in zip(random_image_paths, corresponding_target_paths):
    feature_image_pil = Image.open(feature_image)
    target_image_pil = Image.open(target_image)
    masked_image = get_masked_image(feature_image_pil, target_image_pil)

    transformer_output_dict = transform_train(image=np.array(feature_image_pil), mask=np.array(target_image_pil))
    feature_image_transformed = transformer_output_dict['image']
    target_image_transformed = transformer_output_dict['mask']
    masked_image_transformed = get_masked_image(feature_image_transformed, target_image_transformed)

    plt.figure()
    plt.subplot(3, 2, 1)
    plt.imshow(feature_image_pil, cmap="gray")
    plt.title("feature image before transform")

    plt.subplot(3, 2, 2)
    plt.imshow(feature_image_transformed, cmap="gray")
    plt.title("feature image after transform")

    plt.subplot(3, 2, 3)
    plt.imshow(target_image_pil, cmap="gray")
    plt.title("target image before transform")

    plt.subplot(3, 2, 4)
    plt.imshow(target_image_transformed, cmap="gray")
    plt.title("target image after transform")

    plt.subplot(3, 2, 5)
    plt.imshow(masked_image, cmap="gray")
    plt.title("masked image before transform")

    plt.subplot(3, 2, 6)
    plt.imshow(masked_image_transformed, cmap="gray")
    plt.title("masked image after transform")
    break

## Custom Dataset Class

Up next, we are going to define a custom Dataset class that makes it easy to load images on demand. Using `torch.utils.data.Dataset` as our base class, we will create a custom child class, `FetalHead_Dataset`.

We will use this class to create three objects: `fetal_train_ds`, `fetal_val_ds` and `fetal_test_ds`. The images from the `training_set` folder will be split between `fetal_train_ds` and `fetal_val_ds` to train our model. Then we will use `fetal_test_ds` to sample some images from the `test_set` folder and see how our model performs.

_Note: The images in the `test_set` folder were originally provided so that challenge participants had a standard way of testing their models and making a submission, and their annotations are not included. We will only use these images to check how well the model performs on unseen data._

In [None]:
import torch
from torch.utils.data import Dataset
from scipy import ndimage

class FetalHead_Dataset(Dataset):
    def __init__(self, images_dir:Path, type, transform=None):
        super().__init__()
        _valid_types = ["train","val","test"]
        if type not in _valid_types:
            raise ValueError(f"Invalid dataset type: '{type}'. Use one of these: {_valid_types}")

        self.images_dir = images_dir
        self.type = type
        self.transform = transform

        # Sort for a deterministic order: `glob` does not guarantee any particular order
        self.feature_image_list = sorted(
            path for path in self.images_dir.glob("*.png") if "Annotation" not in path.name
        )
        # Derive each annotation path from its image path ("xxx.png" -> "xxx_Annotation.png"),
        # so that index i always refers to a matching image/annotation pair
        self.target_image_list = [
            path.with_name(path.stem + "_Annotation.png") for path in self.feature_image_list
        ]

    def __len__(self):
        return len(self.feature_image_list)

    def __getitem__(self, index):
        feature_image_path = self.feature_image_list[index]
        feature_image_array = np.array(Image.open(feature_image_path))
        feature_image_tensor = torch.tensor(feature_image_array, dtype=torch.uint8)

        if self.type in ["train","val"]:
            target_image_array = np.array(Image.open(self.target_image_list[index]))
            # Fill the head-circumference outline so the mask covers the whole head region
            target_image_filled = ndimage.binary_fill_holes(target_image_array).astype(np.uint8)
            target_image_tensor = torch.tensor(target_image_filled, dtype=torch.uint8)

            if self.transform:
                # Image and mask go through the same call, so they receive identical random flips
                transformer_output_dict = self.transform(image=feature_image_array, mask=target_image_filled)
                feature_image_tensor = torch.tensor(transformer_output_dict["image"], dtype=torch.uint8)
                target_image_tensor = torch.tensor(transformer_output_dict["mask"], dtype=torch.uint8)

            return feature_image_tensor, target_image_tensor
        else:
            if self.transform:
                transformer_output_dict = self.transform(image=feature_image_array)
                feature_image_tensor = torch.tensor(transformer_output_dict["image"], dtype=torch.uint8)

            return feature_image_tensor

In [None]:
print(train_val_dir)
test_dir = Path.joinpath(DATASET_BASE_PATH, "test_set")
print(test_dir)

fetal_train_ds = FetalHead_Dataset(train_val_dir, "train", transform_train)
fetal_val_ds = FetalHead_Dataset(train_val_dir, "val", transform_val)
fetal_test_ds = FetalHead_Dataset(test_dir, "test", transform_val)

Let's inspect the properties of the three dataset objects created.

In [None]:
train_feature, train_target = fetal_train_ds[0]
train_masked = get_masked_image(train_feature, train_target)

val_feature, val_target = fetal_val_ds[0]
val_masked = get_masked_image(val_feature, val_target)

test_feature = fetal_test_ds[0]

print("Total training features: " ,len(fetal_train_ds))
print("-"*5, "Train Feature", "-"*5)
print("Shape:", train_feature.shape, "Type:", train_feature.dtype, "Max:", train_feature.max())
print("-"*5, "Train Target", "-"*5)
print("Shape:", train_target.shape, "Type:", train_target.dtype, "Max:", train_target.max())
plt.figure(figsize=(20,10))
plt.subplot(1,3,1)
plt.title("Train feature")
plt.imshow(train_feature, cmap="gray")
plt.subplot(1,3,2)
plt.title("Train target")
plt.imshow(train_target, cmap="gray")
plt.subplot(1,3,3)
plt.title("Train masked")
plt.imshow(train_masked, cmap="gray")

print("\nTotal validation features: " ,len(fetal_val_ds))
print("-"*5, "Val Feature", "-"*5)
print("Shape:", val_feature.shape, "Type:", val_feature.dtype, "Max:", val_feature.max())
print("-"*5, "Val Target", "-"*5)
print("Shape:", val_target.shape, "Type:", val_target.dtype, "Max:", val_target.max())
plt.figure(figsize=(20,10))
plt.subplot(1,3,1)
plt.title("Val feature")
plt.imshow(val_feature, cmap="gray")
plt.subplot(1,3,2)
plt.title("Val target")
plt.imshow(val_target, cmap="gray")
plt.subplot(1,3,3)
plt.title("Val masked")
plt.imshow(val_masked, cmap="gray")

print("\nTotal test features: " ,len(fetal_test_ds))
print("-"*5, "Test Feature", "-"*5)
print("Shape:", test_feature.shape, "Type:", test_feature.dtype, "Max:", test_feature.max())
plt.figure(figsize=(10,5))
plt.title("Test feature")
plt.imshow(test_feature, cmap="gray")

Notice the following:-

1. All three datasets have the same `dtype` (`torch.uint8`).
2. Features range from **0-255** while targets range from **0-1**.
3. The training and validation datasets are nearly identical: they have the same length and even the same image order (no shuffling has been done yet), but their images undergo different transformations.
4. The test dataset undergoes the same transformation as the validation dataset.
5. The test dataset has no targets.
6. The target images have been altered: each binary mask has been filled using `scipy.ndimage.binary_fill_holes`. The original annotations only outline the head circumference, and filling the outline turns it into a solid region. This helps because an object is easier to detect when it has sufficient contrast with its background. [_Ref_](https://www.researchgate.net/publication/285371663_Image_Segmentation)
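
The idea behind hole filling can be sketched in pure Python. Background pixels that can be reached from the image border by moving up, down, left or right stay background. Every other background pixel is enclosed by the outline and becomes foreground.

`fill_holes` is a hypothetical helper that illustrates the idea. It is not SciPy's implementation, which uses binary morphology, but it should give the same kind of result for simple closed outlines like the head annotations.

```python
from collections import deque
import numpy as np

def fill_holes(mask: np.ndarray) -> np.ndarray:
    """Fill background regions of a 2-D binary mask that are enclosed by foreground."""
    mask = mask.astype(bool)
    h, w = mask.shape
    outside = np.zeros_like(mask)  # True for background connected to the border
    queue = deque()
    # Seed the search with every background pixel on the image border
    for r in range(h):
        for c in range(w):
            if (r in (0, h - 1) or c in (0, w - 1)) and not mask[r, c]:
                outside[r, c] = True
                queue.append((r, c))
    # Breadth-first search through 4-connected background pixels
    while queue:
        r, c = queue.popleft()
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < h and 0 <= nc < w and not mask[nr, nc] and not outside[nr, nc]:
                outside[nr, nc] = True
                queue.append((nr, nc))
    # Foreground pixels and enclosed holes are everything not reachable from outside
    return ~outside

ring = np.array([
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
], dtype=np.uint8)
filled = fill_holes(ring).astype(np.uint8)  # the centre pixel (2, 2) is now foreground
```

If the outline has a gap, the interior is connected to the border and stays background. This is why a closed annotation outline matters.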

## Data Splitting

In this section, we will split the images from the `training_set` folder so that 80% go to `fetal_train_ds` and 20% go to `fetal_val_ds`.

In [None]:
from sklearn.model_selection import ShuffleSplit
from torch.utils.data import Subset

shuffle_split = ShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
indice_range = range(len(fetal_train_ds))
train_indices, val_indices = next(shuffle_split.split(indice_range))

fetal_train_ds = Subset(fetal_train_ds, train_indices)
fetal_val_ds = Subset(fetal_val_ds, val_indices)
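
Conceptually, `ShuffleSplit` shuffles the indices once and cuts them into two disjoint groups. scikit-learn rounds the test share up, so for 999 images `test_size=0.2` gives 200 validation and 799 training indices.

Both `Subset` objects index files from the same folder, in the same order. This is why splitting indices is enough, even though `fetal_train_ds` and `fetal_val_ds` wrap datasets with different transforms.

The sketch below uses a hypothetical `shuffle_split_indices` helper to mirror the idea. It is not scikit-learn's code.

```python
import math
import numpy as np

def shuffle_split_indices(n_samples: int, test_size: float = 0.2, seed: int = 0):
    """Shuffle range(n_samples) once and cut it into disjoint train/validation indices.

    The validation share is rounded up, e.g. 20% of 999 samples -> 200 indices.
    """
    n_val = math.ceil(test_size * n_samples)
    permutation = np.random.RandomState(seed).permutation(n_samples)
    val_indices = permutation[:n_val]    # first n_val shuffled indices -> validation
    train_indices = permutation[n_val:]  # the remaining indices -> training
    return train_indices, val_indices

train_idx, val_idx = shuffle_split_indices(999)
print(len(train_idx), len(val_idx))  # 799 200
```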

Let's do the same inspection as before, excluding the `fetal_test_ds`.

In [None]:
train_feature, train_target = fetal_train_ds[0]
train_masked = get_masked_image(train_feature, train_target)

val_feature, val_target = fetal_val_ds[0]
val_masked = get_masked_image(val_feature, val_target)

print("Total training features: " ,len(fetal_train_ds))
print("-"*5, "Train Feature", "-"*5)
print("Shape:", train_feature.shape, "Type:", train_feature.dtype, "Max:", train_feature.max())
print("-"*5, "Train Target", "-"*5)
print("Shape:", train_target.shape, "Type:", train_target.dtype, "Max:", train_target.max())
plt.figure(figsize=(20,10))
plt.subplot(1,3,1)
plt.title("Train feature")
plt.imshow(train_feature, cmap="gray")
plt.subplot(1,3,2)
plt.title("Train target")
plt.imshow(train_target, cmap="gray")
plt.subplot(1,3,3)
plt.title("Train masked")
plt.imshow(train_masked, cmap="gray")

print("\nTotal validation features: " ,len(fetal_val_ds))
print("-"*5, "Val Feature", "-"*5)
print("Shape:", val_feature.shape, "Type:", val_feature.dtype, "Max:", val_feature.max())
print("-"*5, "Val Target", "-"*5)
print("Shape:", val_target.shape, "Type:", val_target.dtype, "Max:", val_target.max())
plt.figure(figsize=(20,10))
plt.subplot(1,3,1)
plt.title("Val feature")
plt.imshow(val_feature, cmap="gray")
plt.subplot(1,3,2)
plt.title("Val target")
plt.imshow(val_target, cmap="gray")
plt.subplot(1,3,3)
plt.title("Val masked")
plt.imshow(val_masked, cmap="gray")

In [None]:
# check if there is any overlap/ data leakage between `fetal_train_ds` and `fetal_val_ds`
# True: no overlap, False: there is overlap
set(fetal_train_ds.indices).intersection(set(fetal_val_ds.indices)) == set()

## Creating DataLoaders

In [None]:
from torch.utils.data import DataLoader

fetal_train_dl = DataLoader(fetal_train_ds, batch_size=8, shuffle=True)
fetal_val_dl = DataLoader(fetal_val_ds, batch_size=8, shuffle=False)
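
PyTorch's default collate function stacks the per-sample tensors returned by the dataset along a new leading batch dimension. Each 2-D `[128, 192]` image therefore becomes part of a `[8, 128, 192]` batch. The channel dimension the model expects is added later, in the training loop, with `.unsqueeze(1)`.

The NumPy stand-in below reproduces those shapes. The numbers of batches per epoch follow from the 799/200 split above with the default `drop_last=False`.

```python
import math
import numpy as np

n_batch, h, w = 8, 128, 192

# Stand-ins for eight per-sample images returned by the dataset after resizing
samples = [np.zeros((h, w), dtype=np.uint8) for _ in range(n_batch)]

batch = np.stack(samples)          # stack along a new leading batch axis
batch_nchw = batch[:, np.newaxis]  # insert a channel axis, like .unsqueeze(1)
print(batch.shape, batch_nchw.shape)  # (8, 128, 192) (8, 1, 128, 192)

# Batches per epoch when the last, partial batch is kept
print(math.ceil(799 / n_batch), math.ceil(200 / n_batch))  # 100 25
```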

Let's inspect the shapes and dtypes of the first batch produced by each DataLoader.

In [None]:
feature_image_batch, target_image_batch = next(iter(fetal_train_dl))
print("-"*5,"fetal_train_dl", "-"*5 )
print(feature_image_batch.shape, feature_image_batch.dtype)
print(target_image_batch.shape, target_image_batch.dtype)

feature_image_batch, target_image_batch = next(iter(fetal_val_dl))
print("-"*5,"fetal_val_dl", "-"*5 )
print(feature_image_batch.shape, feature_image_batch.dtype)
print(target_image_batch.shape, target_image_batch.dtype)

## Model definition

One popular model architecture for segmentation tasks is the so-called **encoder-decoder** model.

<p align="center"><img src="https://miro.medium.com/max/1400/1*44eDEuZBEsmG_TCAKRI3Kw@2x.png" width="500"></img></p>

The first half of the model (the encoder) downsizes the feature map using a few convolutional and pooling layers. In the second half (the decoder), the feature map is upsampled until it matches the size of the input image, producing a binary mask. A popular refinement of this design is `U-Net`. It adds skip connections that pass encoder feature maps directly to the corresponding decoder layers.
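
The model defined below applies four 2x2 max-pooling steps and then four 2x upsampling steps. The spatial size is therefore halved four times and then doubled four times.

The hypothetical helper `encoder_decoder_sizes` traces this arithmetic in plain Python. It shows why the resized height and width should both be divisible by 2^4 = 16, as 128 and 192 are. Otherwise, the floor division in pooling drops pixels and the output mask no longer matches the input size.

```python
def encoder_decoder_sizes(height: int, width: int, num_pools: int = 4):
    """Trace (H, W) through num_pools 2x2 max-pools followed by as many 2x upsamples."""
    sizes = [(height, width)]
    for _ in range(num_pools):  # encoder: max_pool2d(x, 2, 2) halves H and W (rounding down)
        height, width = height // 2, width // 2
        sizes.append((height, width))
    for _ in range(num_pools):  # decoder: Upsample(scale_factor=2) doubles H and W
        height, width = height * 2, width * 2
        sizes.append((height, width))
    return sizes

print(encoder_decoder_sizes(128, 192))
# [(128, 192), (64, 96), (32, 48), (16, 24), (8, 12), (16, 24), (32, 48), (64, 96), (128, 192)]
print(encoder_decoder_sizes(130, 192)[-1])  # (128, 192): two rows lost, output no longer matches
```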

Let's proceed to define the model class and instantiate the model.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class SegNet(nn.Module):
    def __init__(self, params) -> None:
        super().__init__()
        C_in, H_in, W_in = params["input_shape"]
        init_f = params["initial_filters"]
        num_outputs = params["num_outputs"]

        self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(init_f, 2 * init_f, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(2 * init_f, 4 * init_f, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(4 * init_f, 8 * init_f, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(8 * init_f, 16 * init_f, kernel_size=3, padding=1)

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        self.conv_up1 = nn.Conv2d(16 * init_f, 8 * init_f, kernel_size=3, padding=1)
        self.conv_up2 = nn.Conv2d(8 * init_f, 4 * init_f, kernel_size=3, padding=1)
        self.conv_up3 = nn.Conv2d(4 * init_f, 2 * init_f, kernel_size=3, padding=1)
        self.conv_up4 = nn.Conv2d(2 * init_f, init_f, kernel_size=3, padding=1)

        self.conv_out = nn.Conv2d(init_f, num_outputs, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv5(x))

        x = self.upsample(x)
        x = F.relu(self.conv_up1(x))

        x = self.upsample(x)
        x = F.relu(self.conv_up2(x))

        x = self.upsample(x)
        x = F.relu(self.conv_up3(x))

        x = self.upsample(x)
        x = F.relu(self.conv_up4(x))

        x = self.conv_out(x)
        return x

In [None]:
# Instantiate the model
model_params = {
    "input_shape":(1,height,width), # the shape of the input data in a tuple-format: (channel, height, width)
    "initial_filters": 16,          # the number of filters for the first CNN layer
    "num_outputs": 1                # only 1 class for single object segmentation (binary)
}
model = SegNet(model_params)

In [None]:
# Move the model to the GPU, if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# UNCOMMENT THE FOLLOWING LINE IF ENFORCE CPU ONLY
# device = torch.device("cpu")

model = model.to(device)

Let's visualize the layers of the model using `print(model)`.

In [None]:
print(model)

Using a third-party package, `torchsummary`, we can obtain further insights. These include the output shape of each layer, the total numbers of trainable and non-trainable parameters, and the estimated size of the model (`Params size`).

In [None]:
from torchsummary import summary

summary(model, input_size=(1,height, width), device=device.type)
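
As a cross-check of the summary, the parameter count can be derived by hand. A `Conv2d` layer with bias has out·in·k·k weights plus one bias per output channel. `Upsample`, ReLU and max-pooling have no learnable parameters.

The sketch below (hypothetical helper names) applies this to the ten 3x3 convolutions of `SegNet` with `initial_filters=16`. The result should agree with the `Total params` line reported by `summary`.

```python
def conv2d_params(in_channels: int, out_channels: int, kernel_size: int = 3) -> int:
    """Weights (out * in * k * k) plus one bias per output channel."""
    return out_channels * in_channels * kernel_size * kernel_size + out_channels

def segnet_params(in_channels: int = 1, init_f: int = 16, num_outputs: int = 1) -> int:
    """Total learnable parameters of the SegNet class above (all layers are 3x3 convolutions)."""
    # conv1..conv5: channel counts along the encoder
    encoder = [in_channels, init_f, 2 * init_f, 4 * init_f, 8 * init_f, 16 * init_f]
    # conv_up1..conv_up4 and conv_out: channel counts along the decoder
    decoder = [16 * init_f, 8 * init_f, 4 * init_f, 2 * init_f, init_f, num_outputs]
    total = 0
    for channels in (encoder, decoder):
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            total += conv2d_params(c_in, c_out)
    return total

print(segnet_params())  # 784385
```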

## Loss function and Optimizer

So far, we have loaded the datasets into DataLoader objects, instantiated the model and moved it to the GPU, if one is available. To train the model, we now need a loss function, and an optimizer that updates the model parameters based on the gradients of the loss.

The classical loss function for single-object segmentation is the binary cross-entropy (BCE) loss, which compares each pixel of the prediction with the corresponding pixel of the ground truth. However, we can combine multiple criteria to improve the overall segmentation performance. A popular technique is to **combine the Dice metric with the BCE loss**. The Dice metric is commonly used to evaluate segmentation algorithms by measuring the overlap between the ground truth and the prediction.

Let's begin by defining a function that calculates the Dice loss and the Dice coefficient for a batch of predictions.
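
For reference, here are the formulas, written per image. They cover the Dice coefficient between a predicted mask $P$ and a ground-truth mask $G$, its "soft" version for network outputs, and the combined loss:

- $z_i$ is the logit for pixel $i$.
- $p_i = \sigma(z_i)$ is its predicted probability.
- $g_i \in \{0, 1\}$ is its ground-truth label.

The Dice coefficient ranges from 0 (no overlap) to 1 (perfect overlap). In code, a small smoothing constant is usually added to avoid division by zero when both masks are empty.

$$
\mathrm{Dice}(P, G) = \frac{2\,|P \cap G|}{|P| + |G|},
\qquad
\mathrm{Dice}_{\mathrm{soft}} = \frac{2 \sum_i p_i\, g_i}{\sum_i p_i + \sum_i g_i}
$$

$$
\mathcal{L} = \underbrace{-\sum_i \bigl[g_i \log p_i + (1 - g_i)\log(1 - p_i)\bigr]}_{\text{BCE, summed over pixels}} + \bigl(1 - \mathrm{Dice}_{\mathrm{soft}}\bigr)
$$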

In [None]:
def dice_loss_batch(predicted_batch, target_batch, smooth=1e-5):
    """Function to calculate the dice loss per data batch"""
    predicted_batch = torch.sigmoid(predicted_batch) # map logits to probabilities in the range 0-1
    intersection = (predicted_batch * target_batch).sum(dim=(2,3)) # one value per image (and channel)
    union = predicted_batch.sum(dim=(2,3)) + target_batch.sum(dim=(2,3))
    # `smooth` avoids division by zero; keeping it outside the factor of 2 keeps dice_coeff within [0, 1]
    dice_coeff = (2.0 * intersection + smooth) / (union + smooth)
    dice_loss = 1.0 - dice_coeff
    return dice_loss.sum(), dice_coeff.sum() # both summed over the batch
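
To get a feel for the numbers, here is a NumPy stand-in for the soft Dice coefficient of a single mask. `soft_dice` is a hypothetical helper, not part of PyTorch. It writes the smoothing term as (2·intersection + smooth) / (union + smooth), which keeps the coefficient in [0, 1], even when both masks are empty.

```python
import numpy as np

def soft_dice(logits: np.ndarray, target: np.ndarray, smooth: float = 1e-5) -> float:
    """Soft Dice coefficient for a single mask (NumPy stand-in for dice_loss_batch)."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid: logits -> probabilities
    intersection = (probs * target).sum()
    union = probs.sum() + target.sum()
    return float((2.0 * intersection + smooth) / (union + smooth))

target = np.array([1.0, 1.0, 0.0, 0.0])
confident_half = np.array([20.0, -20.0, 20.0, -20.0])   # one hit, one miss, one false positive
confident_exact = np.array([20.0, 20.0, -20.0, -20.0])  # matches the target
print(round(soft_dice(confident_half, target), 3))   # 0.5
print(round(soft_dice(confident_exact, target), 3))  # 1.0
```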

We can use `dice_coeff` as a validation metric for our model, but the loss is not complete until we combine the Dice loss with the BCE loss. Next, let's define another function to calculate the combined loss (Dice + BCE).

In [None]:
def loss_function_batch(predicted_batch, target_batch):
    """Function to calculate the combined loss (dice + BCE) per data batch"""
    bce_loss = F.binary_cross_entropy_with_logits(predicted_batch, target_batch, reduction="sum")
    dice_loss, _ = dice_loss_batch(predicted_batch, target_batch)
    combined_loss = bce_loss + dice_loss
    return combined_loss

We could have returned `dice_coeff` as part of the output of `loss_function_batch`, but we deliberately chose not to. It is more maintainable to have one function that returns only the loss and another that returns only the validation metric. Either can then be changed later without touching the other.

Let's create a function to return the `dice_coeff` as a validation metric.

In [None]:
def validation_metric_batch(predicted_batch, target_batch):
    """Function to return the validation metric"""
    _, dice_coeff = dice_loss_batch(predicted_batch, target_batch)
    return dice_coeff

Great! We now have two helpful functions:
1. A function to calculate loss for one batch of data
2. A function to calculate the validation metric for one batch of data

Let's incorporate these functions into a `loss_epoch` function that iterates through an entire dataset and calculates the average loss and validation metric. The function is designed so that we can conveniently swap datasets (`DataLoader` objects) and other training parameters.

In [None]:
from torch import optim

def loss_epoch(
    model: nn.Module,
    loss_function_batch,
    validation_metric_batch,
    dataset_dl: DataLoader,
    optimizer: optim.Optimizer=None,
    device = torch.device("cpu"),
    sanity_check: bool = False, # A flag used to test the training pipeline, trains the model only for a single batch (or single iteration)
):
    running_loss = 0.0
    running_metric = 0.0
    total_dataset_size = len(dataset_dl.dataset)

    for feature_image_batch, target_image_batch in dataset_dl:
        feature_image_batch = feature_image_batch.type(torch.float32).unsqueeze(1).to(device) # Makes the 3D shape input [8,128,192] to 4D shape [8,1,128,192]
        target_image_batch = target_image_batch.type(torch.float32).unsqueeze(1).to(device)
        predicted_image_batch = model(feature_image_batch)

        batch_loss = loss_function_batch(predicted_image_batch, target_image_batch)
        batch_metric = validation_metric_batch(predicted_image_batch, target_image_batch)

        if optimizer is not None:
            optimizer.zero_grad(set_to_none=True) # sets the .grad attribute of each parameter (tensor) to None
            batch_loss.backward() # computes the gradient for every parameter that has `requires_grad=True`
            optimizer.step() # updates the parameters based on their `.grad` attribute

        running_loss += batch_loss.detach()
        running_metric += batch_metric.detach()

        if sanity_check is True:
            break

    epoch_loss = running_loss / float(total_dataset_size) # average loss per image (the batch functions return sums)
    epoch_metric = running_metric / float(total_dataset_size) # average validation metric per image

    # using `.item()` because we will just be collecting the numerical values, not other attributes like `.grad_fn` etc.
    # crucial step to avoid Out Of Memory issues. STORE ONLY WHAT YOU NEED! 
    return epoch_loss.item(), epoch_metric.item()
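
`loss_function_batch` and `validation_metric_batch` return values summed over the images in a batch. `loss_epoch` then divides the running totals by the number of images in the dataset.

The toy calculation below (hypothetical helper names) shows why this gives a true per-image average, even when the last batch is smaller. Averaging per-batch means instead would over-weight a small final batch. With 799 training images and `batch_size=8`, the last batch has only 7 images.

```python
def per_sample_mean(batch_sums, dataset_size):
    """Per-sample average from per-batch summed losses (the approach used in loss_epoch)."""
    return sum(batch_sums) / dataset_size

def mean_of_batch_means(batch_sums, batch_sizes):
    """Naive alternative: averaging per-batch means over-weights a smaller final batch."""
    return sum(s / n for s, n in zip(batch_sums, batch_sizes)) / len(batch_sizes)

# 10 samples with batch_size=8 -> batches of 8 and 2 samples.
# Each sample in batch 1 has loss 1.0 (sum 8.0); each sample in batch 2 has loss 4.0 (sum 8.0).
batch_sums, batch_sizes = [8.0, 8.0], [8, 2]
print(per_sample_mean(batch_sums, sum(batch_sizes)))  # 1.6, the true mean over 10 samples
print(mean_of_batch_means(batch_sums, batch_sizes))   # 2.5
```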

## Model training

Training a model is an iterative process. In every epoch, we:-
1. Feed mini-batches from the training dataset into the model.
2. Calculate the training loss and metric, and update the model parameters accordingly.
3. Feed mini-batches from the validation dataset into the model.
4. Calculate the validation loss and metric, and [create a general checkpoint](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html) if the validation loss improves on the best value so far.
5. Print the relevant details of the epoch.
6. Repeat steps 1-5 for the next epoch.

Now that we understand the flow, let's implement it in a function, `train_model()`.

In [None]:
import copy

def train_model(model, params):
    num_epochs = params["num_epochs"]
    loss_func = params["loss_func"]
    val_func = params["val_func"]
    optimizer = params["optimizer"]
    train_dl = params["train_dl"]
    val_dl = params["val_dl"]
    sanity_check = params["sanity_check"]
    checkpoint_path = params["checkpoint_path"]
    device = params["device"]

    loss_history = {"train": [], "val": []}
    val_metric_history = {"train": [], "val": []}
    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        model.train()
        epoch_train_loss, epoch_train_metric = loss_epoch(model, loss_func, val_func, train_dl, optimizer, device, sanity_check)
        loss_history["train"].append(epoch_train_loss)
        val_metric_history["train"].append(epoch_train_metric)

        model.eval()
        with torch.no_grad():
            epoch_val_loss, epoch_val_metric = loss_epoch(model, loss_func, val_func, val_dl, device=device, sanity_check=sanity_check)

        loss_history["val"].append(epoch_val_loss)
        val_metric_history["val"].append(epoch_val_metric)

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': epoch_val_loss
                }, 
                checkpoint_path)
            print("Saved a new checkpoint for the best model so far!")

        print(f"epoch_train_loss: {epoch_train_loss:.6f} dice: {100*epoch_train_metric:.2f}")
        print(f"epoch_val_loss: {epoch_val_loss:.6f} dice: {100*epoch_val_metric:.2f}")
        print("-" * 10)

    print("Loading best model weights!")
    model.load_state_dict(best_model_wts)
    return model, loss_history, val_metric_history

We have now done all the preparation needed to train our model. Let's not wait any longer and begin training.

In [None]:
MODEL_SAVE_BASE = Path("../generated_models").resolve()
MODEL_SAVE_BASE.mkdir(parents=True, exist_ok=True)  # also creates missing parent folders

# define parameters for train_model function
num_epochs = 30
optimizer = optim.Adam(model.parameters(), lr=3e-4)
sanity_check = False
checkpoint_path = Path.joinpath(MODEL_SAVE_BASE, f"ENC_DEC_Segmentation_{num_epochs}.pt")

params_train={
    "num_epochs": num_epochs,
    "loss_func": loss_function_batch,
    "val_func": validation_metric_batch,
    "optimizer": optimizer,
    "train_dl": fetal_train_dl,
    "val_dl": fetal_val_dl,
    "sanity_check": sanity_check,
    "checkpoint_path": checkpoint_path,
    "device": device
}

model, loss_history, val_metric_history = train_model(model, params_train)

Let's visualize the historical loss and metric values.

In [None]:
# Plot historical values
plt.figure(figsize=(10,5))
plt.title("Train-Val Loss")
plt.plot(range(1, num_epochs + 1), loss_history["train"], label="train")
plt.plot(range(1, num_epochs + 1), loss_history["val"], label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()

plt.figure(figsize=(10,5))
plt.title("Train-Val Dice Coefficient")  # the tracked metric is the Dice coefficient, not accuracy
plt.plot(range(1, num_epochs + 1), val_metric_history["train"], label="train")
plt.plot(range(1, num_epochs + 1), val_metric_history["val"], label="val")
plt.ylabel("Dice coefficient")
plt.xlabel("Training Epochs")
plt.legend()

Let's inspect the saved checkpoint to get more insights into the previous run. This information may be useful for future experiments.

In [None]:
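# map_location="cpu" lets a checkpoint saved on a GPU be loaded on a CPU-only machine
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print("Best Epoch:", checkpoint["epoch"] + 1) # the saved epoch index is 0-based; the training log counts from 1
print("Best Loss:", checkpoint["loss"])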

## Model inferencing

Let's assume that we no longer have the `model` variable. We'll instantiate a new model from the saved checkpoint to test its segmentation ability on data from the `test_set` folder.

In [None]:
model_new = SegNet(model_params)
model_new.load_state_dict(checkpoint["model_state_dict"])
model_new.eval()
print(model_new)

In [None]:
total_images_to_test = 4
# replace=False avoids showing the same test image twice
random_indexes = np.random.choice(len(fetal_test_ds), total_images_to_test, replace=False)

for idx in random_indexes:
    feature_image_tensor = fetal_test_ds[idx]
    feature_image = feature_image_tensor.numpy()

    # preprocess tensor input to dtype: torch.float32 and shape: from 2D [H, W] to 4D [1, 1, H, W]
    feature_image_tensor = feature_image_tensor.type(torch.float32).view(1, 1, height, width)

    with torch.no_grad():
        predicted_image_tensor = model_new(feature_image_tensor)

    # post-process tensor output: shape from 4D to 2D, map logits to probabilities using sigmoid,
    # then threshold at 0.5 to get a binary (0/1) mask as a NumPy array
    predicted_image = (torch.sigmoid(predicted_image_tensor.squeeze()) >= 0.5).numpy().astype(np.uint8)
    
    masked_image = get_masked_image(feature_image, predicted_image)

    plt.figure()
    plt.subplot(1,3,1)
    plt.imshow(feature_image, cmap="gray")
    plt.title("Feature image")
    plt.subplot(1,3,2)
    plt.imshow(predicted_image, cmap="gray")
    plt.title("Predicted image")
    plt.subplot(1,3,3)
    plt.imshow(masked_image, cmap="gray")
    plt.title("Masked image")
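
The inference loop above thresholds `torch.sigmoid(...)` at 0.5. The sigmoid is monotonic and σ(0) = 0.5, so this is mathematically the same as keeping pixels whose logit is non-negative. In floating point, the two can only disagree for logits extremely close to zero.

A NumPy sketch with hypothetical helper names:

```python
import numpy as np

def mask_from_probabilities(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Apply a sigmoid, then threshold the probabilities (as the inference loop does)."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs >= threshold).astype(np.uint8)

def mask_from_logits(logits: np.ndarray) -> np.ndarray:
    """Shortcut for a 0.5 threshold: sigmoid(z) >= 0.5 exactly when z >= 0."""
    return (logits >= 0).astype(np.uint8)

logits = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
print(mask_from_probabilities(logits), mask_from_logits(logits))  # [0 0 1 1 1] [0 0 1 1 1]
```

The logit form skips the sigmoid entirely. The probability form is easier to adapt if you later want a threshold other than 0.5.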