# U-Net Segmentation - CAMUS Dataset

*Made by **Hang Jung Ling** and **Olivier Bernard** from the INSA Lyon, France.*

This notebook shows how to train, test and evaluate a U-Net to segment cardiac structures on the [CAMUS dataset](https://humanheart-project.creatis.insa-lyon.fr/database/#collection/6373703d73e9f0047faa1bc8).

CAMUS is one of the largest public echocardiographic datasets. It contains 500 patients, each with 4 echocardiographic images: end-diastolic (ED) and end-systolic (ES) frames acquired in both apical two-chamber and apical four-chamber views. Each image is annotated by an expert and contains 3 classes plus background:</br>
&emsp;1) Left ventricle</br>
&emsp;2) Myocardium</br>
&emsp;3) Left atrium</br>

Summary:</br>
&emsp;I.   [Install dependencies](#install)</br>
&emsp;II.  [Dataset](#dataset)</br>
&emsp;III. [Train](#train)</br>
&emsp;IV.  [Visualize](#visualize)</br>
&emsp;V.   [Evaluate](#evaluate)</br>


# I. Install dependencies <a class="anchor" id="install"></a>

Skip this step if you have already created your environment from `environment.yaml`. Otherwise, execute the following cells to install the dependencies.

In [None]:
%%capture project_path_setup

import sys

if "../" in sys.path:
    print(sys.path)
else:
    sys.path.append("../")
    print(sys.path)

In [None]:
%%capture packages_install

# Make sure the repo's package and its dependencies are installed
%pip install -e ../.

# II. Dataset <a class="anchor" id="dataset"></a>

Once the environment is set up, download the CAMUS dataset by executing the following cell. The dataset will be downloaded to the `data/` folder.

In [None]:
from pathlib import Path

# Make sure the data is downloaded and extracted where it should be
if not Path("../data/camus_64.zip").is_file():
    !wget "https://www.creatis.insa-lyon.fr/~bernard/camus/camus_64.zip" --directory-prefix="../data/"
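    # Extract next to the archive so that `../data/camus_64` exists for the next cells
    # (assumes the archive contains a top-level `camus_64/` folder)
    !unzip -qq ../data/camus_64.zip -d ../data/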
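
If `wget` or `unzip` is not available on your system, the same step can be done with the Python standard library alone. The following is a minimal sketch, not part of the original notebook: the helper name `download_and_extract` is hypothetical, and it assumes the archive can be extracted directly into the data folder.

```python
import urllib.request
import zipfile
from pathlib import Path


def download_and_extract(url: str, data_dir: Path) -> Path:
    """Download a zip archive into data_dir (if missing) and extract it there."""
    data_dir.mkdir(parents=True, exist_ok=True)
    # Name the local file after the last component of the URL
    archive = data_dir / url.rsplit("/", 1)[-1]
    if not archive.is_file():
        urllib.request.urlretrieve(url, archive)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(data_dir)
    return archive


# Usage (commented out because it needs network access):
# download_and_extract(
#     "https://www.creatis.insa-lyon.fr/~bernard/camus/camus_64.zip",
#     Path("../data"),
# )
```

The download is skipped when the archive is already present, which mirrors the `is_file()` check in the cell above.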

Now, let's split the data into training, validation and test sets: 80% of the patients for training, 10% for validation and 10% for testing. The split is done by patient ID, so that no patient appears in more than one set.

In [None]:
from sklearn.model_selection import train_test_split

from src.utils.file_and_folder_operations import subdirs

# Specify the data directory
data_dir = Path("../data/camus_64").resolve()

# List all the patient IDs
keys = subdirs(data_dir, prefix="patient", join=False)

# Split the patients into 80/10/10 train/val/test sets
train_keys, val_and_test_keys = train_test_split(keys, train_size=0.8, random_state=12345)
val_keys, test_keys = train_test_split(val_and_test_keys, test_size=0.5, random_state=12345)

train_keys = sorted(train_keys)
val_keys = sorted(val_keys)
test_keys = sorted(test_keys)

# Create train, val and test datalists
views_instants = ["2CH_ED", "2CH_ES", "4CH_ED", "4CH_ES"]


def make_datalist(patient_keys):
    """Return one image/label path pair per patient and per view/instant."""
    return [
        {
            "image": str(data_dir / key / f"{key}_{view}.nii.gz"),
            "label": str(data_dir / key / f"{key}_{view}_gt.nii.gz"),
        }
        for key in patient_keys
        for view in views_instants
    ]


train_datalist = make_datalist(train_keys)
val_datalist = make_datalist(val_keys)
test_datalist = make_datalist(test_keys)

print("Example of train datalist entries: ", train_datalist[:5])
print("Example of validation datalist entries: ", val_datalist[:5])
print("Example of test datalist entries: ", test_datalist[:5])
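
A patient-level split is only useful if no patient ends up in two sets, so it can be worth checking this explicitly. The sketch below is an illustration rather than part of the original notebook: `check_patient_split` is a hypothetical helper, and the patient IDs are made-up stand-ins for the real folder names. It assumes all 500 patients are present, each with 4 images.

```python
def check_patient_split(train_keys, val_keys, test_keys, images_per_patient=4):
    """Verify that no patient is shared between sets and return image counts per set."""
    train, val, test = set(train_keys), set(val_keys), set(test_keys)
    overlap = (train & val) | (train & test) | (val & test)
    if overlap:
        raise ValueError(f"Patients present in more than one set: {sorted(overlap)}")
    return {
        "train": len(train) * images_per_patient,
        "val": len(val) * images_per_patient,
        "test": len(test) * images_per_patient,
    }


# Hypothetical patient IDs standing in for the real folder names
patients = [f"patient{i:04d}" for i in range(1, 501)]
counts = check_patient_split(patients[:400], patients[400:450], patients[450:])
print(counts)  # {'train': 1600, 'val': 200, 'test': 200}
```

With an 80/10/10 split of 500 patients, this gives 400/50/50 patients, i.e. 1600/200/200 images.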

Once the data is split, we will create a `Dataset` object for each set. This object will be used to load the data during training and testing.

In [None]:
import numpy as np
from monai.data import CacheDataset
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    NormalizeIntensityd,
    RandAdjustContrastd,
    RandFlipd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandRotated,
    RandScaleIntensityd,
    RandZoomd,
)

# Transforms to load data
load_transforms = [
    LoadImaged(keys=["image", "label"], image_only=True),  # Load image and label
    EnsureChannelFirstd(
        keys=["image", "label"]
    ),  # Make sure the first dimension is the channel dimension
    NormalizeIntensityd(keys=["image"]),  # Normalize the intensity of the image
]

# Transforms to augment data
range_x = [-15.0 / 180 * np.pi, 15.0 / 180 * np.pi]
data_augmentation_transforms = [
    RandRotated(
        keys=["image", "label"],
        range_x=range_x,
        range_y=0,
        range_z=0,
        mode=["bicubic", "nearest"],
        padding_mode="constant",
        prob=0.2,
    ),
    RandZoomd(
        keys=["image", "label"],
        min_zoom=0.7,
        max_zoom=1.4,
        mode=["bicubic", "nearest"],
        padding_mode="constant",
        align_corners=(True, None),
        prob=0.2,
    ),
    RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
    RandGaussianSmoothd(
        keys=["image"],
        sigma_x=(0.5, 1.15),
        sigma_y=(0.5, 1.15),
        prob=0.15,
    ),
    RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
    RandAdjustContrastd(keys=["image"], gamma=(0.7, 1.5), prob=0.3),
    RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.5),
]

# Define transforms for training, validation and testing
train_transforms = Compose(load_transforms + data_augmentation_transforms)
val_transforms = Compose(load_transforms)
test_transforms = Compose(load_transforms)

train_ds = CacheDataset(data=train_datalist, transform=train_transforms, cache_rate=1.0)
val_ds = CacheDataset(data=val_datalist, transform=val_transforms, cache_rate=1.0)
test_ds = CacheDataset(data=test_datalist, transform=test_transforms, cache_rate=1.0)
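
To make the normalization step more concrete: with its default arguments, `NormalizeIntensityd` standardizes each image using that image's own mean and standard deviation. The NumPy sketch below illustrates this idea under that assumption; `zscore_normalize` is a hypothetical name and the input is random data, not a CAMUS image.

```python
import numpy as np


def zscore_normalize(image: np.ndarray) -> np.ndarray:
    """Subtract the image mean and divide by its standard deviation."""
    std = image.std()
    # Guard against constant images, for which the standard deviation is zero
    return (image - image.mean()) / (std if std > 0 else 1.0)


rng = np.random.default_rng(0)
image = rng.uniform(0, 255, size=(1, 64, 64)).astype(np.float32)
normalized = zscore_normalize(image)
print(normalized.mean(), normalized.std())  # close to 0 and 1
```

After this step, every image has roughly zero mean and unit variance, which keeps intensity ranges comparable across patients and acquisitions.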

Now, let's visualize a random image from each of the training, validation and test sets. Each image is displayed next to its ground truth segmentation mask.

In [None]:
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

from src.utils.visualizations import imagesc

# Get a random image with label from each dataset
train_idx = np.random.randint(len(train_ds))
val_idx = np.random.randint(len(val_ds))
test_idx = np.random.randint(len(test_ds))

print("train_idx: ", train_idx)
print("val_idx: ", val_idx)
print("test_idx: ", test_idx)

# Visualize a random image with label from each dataset
colors = ["black", "red", "green", "blue"]
cmap = ListedColormap(colors)

figure = plt.figure(figsize=(8, 8))
train_sample = train_ds[train_idx]
image = train_sample["image"].detach().cpu().numpy()[0].transpose(1, 0)
label = train_sample["label"].detach().cpu().numpy()[0].transpose(1, 0)
ax = figure.add_subplot(3, 2, 1)
imagesc(ax, image, title="Training image", show_colorbar=False)
ax = figure.add_subplot(3, 2, 2)
plt.imshow(label, cmap=cmap, interpolation="nearest")
plt.title("Training label")
ax.axis("off")

val_sample = val_ds[val_idx]
image = val_sample["image"].detach().cpu().numpy()[0].transpose(1, 0)
label = val_sample["label"].detach().cpu().numpy()[0].transpose(1, 0)
ax = figure.add_subplot(3, 2, 3)
imagesc(ax, image, title="Validation image", show_colorbar=False)
ax = figure.add_subplot(3, 2, 4)
plt.imshow(label, cmap=cmap, interpolation="nearest")
plt.title("Validation label")
ax.axis("off")

test_sample = test_ds[test_idx]
image = test_sample["image"].detach().cpu().numpy()[0].transpose(1, 0)
label = test_sample["label"].detach().cpu().numpy()[0].transpose(1, 0)
ax = figure.add_subplot(3, 2, 5)
imagesc(ax, image, title="Test image", show_colorbar=False)
ax = figure.add_subplot(3, 2, 6)
plt.imshow(label, cmap=cmap, interpolation="nearest")
plt.title("Test label")
ax.axis("off")
figure.tight_layout()
plt.show()

# III. Train <a class="anchor" id="train"></a>
Let's now train a U-Net to segment the left ventricle, myocardium and left atrium. We will use the training and validation sets created in the previous section.

### Definition of U-Net architecture

In [None]:
from torchinfo import summary

from src.models.unet import UNet

input_channels = 1  # This is the number of input channels in the image
input_shape = (input_channels, 64, 64)  # This is the shape of the input image to the network
output_channels = 4  # This is the number of output classes
output_shape = (output_channels, 64, 64)  # This is the shape of the output mask
init_channels = 32  # This is the number of channels in the first layer of the network

unet = UNet(input_shape=input_shape, output_shape=output_shape, init_channels=init_channels)

# Print the summary of the network
summary_kwargs = dict(
    col_names=["input_size", "output_size", "kernel_size", "num_params"], depth=3, verbose=0
)
summary(unet, (1, *input_shape), device="cpu", **summary_kwargs)

### Definition of optimizer, loss function, and metrics

In [None]:
from functools import partial

import torch
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric

# Soft dice and CE loss function
loss_function = DiceCELoss(
    include_background=False,
    batch=True,
    smooth_nr=0.00001,
    smooth_dr=0.00001,
    lambda_dice=0.5,
    lambda_ce=0.5,
)

# Adam optimizer
optimizer = torch.optim.Adam(unet.parameters(), lr=0.001)

# Hard dice metric
metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
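
For intuition about the hard Dice metric, the score of one class is $2|P \cap G| / (|P| + |G|)$, where $P$ and $G$ are the predicted and ground truth pixels of that class. The following NumPy sketch computes it per foreground class on a toy example. It is an illustration only: `hard_dice_per_class` is a hypothetical helper, not MONAI's implementation.

```python
import numpy as np


def hard_dice_per_class(pred, target, num_classes):
    """Return the Dice score of each foreground class (background 0 is skipped)."""
    scores = []
    for c in range(1, num_classes):
        p, t = pred == c, target == c
        denom = p.sum() + t.sum()
        # Class absent from both masks: report NaN rather than a misleading score
        scores.append(2.0 * (p & t).sum() / denom if denom > 0 else float("nan"))
    return np.array(scores)


pred = np.array([[1, 1, 0], [0, 2, 2]])
target = np.array([[1, 0, 0], [0, 2, 2]])
scores = hard_dice_per_class(pred, target, num_classes=3)
print(scores)         # class 1: 2*1/(2+1) = 0.667, class 2: 2*2/(2+2) = 1.0
print(scores.mean())  # 0.833...
```

Skipping class 0 corresponds to `include_background=False`: the background usually dominates the image and would inflate the mean score.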

### Definition of training function

In [None]:
import time

from monai.data import DataLoader, decollate_batch
from monai.transforms import AsDiscrete


def train_process(
    train_ds,
    val_ds,
    model,
    device,
    loss_function,
    optimizer,
    metric,
    max_epochs,
    log_dir,
    val_interval=1,
):
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=8,
    )
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
    model = model.to(device)
    # `loss_function` and `optimizer` are supplied by the caller and used as-is.
    # Redefining them here would silently ignore the arguments passed to this function.

    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=model.output_shape[0])])
    post_label = Compose([AsDiscrete(to_onehot=model.output_shape[0])])

    # Use the metric passed in by the caller instead of creating a new one
    dice_metric = metric

    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    epoch_times = []
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(
                f"{step}/{len(train_loader)},"
                f" train_loss: {loss.item():.4f}"
                f" step time: {(time.time() - step_start):.4f}"
            )
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    val_outputs = model(val_inputs)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                # Use a dedicated name so the `metric` argument is not shadowed
                mean_dice = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(mean_dice)
                if mean_dice > best_metric:
                    best_metric = mean_dice
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(),
                        str(log_dir / "best_metric_model.pth"),
                    )
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current"
                    f" mean dice: {mean_dice:.4f}"
                    f" best mean dice: {best_metric:.4f}"
                    f" at epoch: {best_metric_epoch}"
                )
        # Measure the epoch duration once so the printed and stored values match
        epoch_time = time.time() - epoch_start
        print(f"time of epoch {epoch + 1} is: {epoch_time:.4f}")
        epoch_times.append(epoch_time)

    print(
        f"train completed, best_metric: {best_metric:.4f}"
        f" at epoch: {best_metric_epoch}"
        f" total time: {(time.time() - total_start):.4f}"
    )
    return (
        max_epochs,
        time.time() - total_start,
        epoch_loss_values,
        metric_values,
        epoch_times,
    )
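
During validation, `post_pred` turns the raw network outputs into a one-hot mask (argmax over the channel dimension, then one-hot encoding) so that they can be compared with the one-hot labels. The NumPy sketch below reproduces this idea on a tiny 3-class example. It is an illustration with made-up scores: `logits_to_onehot` is a hypothetical helper, not the MONAI transform itself.

```python
import numpy as np


def logits_to_onehot(logits: np.ndarray) -> np.ndarray:
    """Convert (C, H, W) class scores into a (C, H, W) one-hot mask via argmax."""
    num_classes = logits.shape[0]
    labels = logits.argmax(axis=0)  # (H, W) map of the most likely class
    # Index an identity matrix to one-hot encode, then move classes to the front
    return np.eye(num_classes, dtype=np.float32)[labels].transpose(2, 0, 1)


logits = np.array(
    [
        [[2.0, 0.1], [0.3, 0.2]],  # channel 0
        [[0.5, 1.5], [0.1, 0.1]],  # channel 1
        [[0.1, 0.2], [0.9, 0.8]],  # channel 2
    ]
)
onehot = logits_to_onehot(logits)
print(onehot.shape)           # (3, 2, 2)
print(onehot.argmax(axis=0))  # class map with rows [0 1] and [2 2]
```

Each pixel belongs to exactly one class after this step, which is what makes the resulting Dice score a "hard" metric.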