In [None]:
import sys
sys.path.append('..')
import os
import shutil
import numpy as np
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import nibabel as nib
import json

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
import torchmetrics

from utils.Task1_utils import (
    Preprocess_ACDC,
    load_config,
    ACDC_Dataset,
    make_serializeable_metrics,
    skin_plot
)

# Task1.1: Build a Vanilla U-Net

<img src="../img/U-Net.png" width="600" height="400">

Please complete the `UNet` class below so that it follows the architecture shown in the figure.

**Contracting path (Encoder containing downsampling steps):**

Images are first fed through several convolutional and pooling layers, which reduce the height and width while increasing the number of channels.

The contracting path follows a regular CNN architecture, with convolutional layers, their activations, and pooling layers to downsample the image and extract its features. In detail, it consists of the repeated application of two 3 x 3 unpadded convolutions, each followed by a rectified linear unit (ReLU) and a 2 x 2 max pooling operation with stride 2 for downsampling. At each downsampling step, the number of feature channels is doubled.

Crop function: This step crops the image from the contracting path and concatenates it to the current image on the expanding path to create a skip connection.

**Expanding path (Decoder containing upsampling steps):**

The expanding path performs the opposite operation of the contracting path, growing the image back to its original size, while shrinking the channels gradually.

In detail, each step in the expanding path upsamples the feature map, followed by a 2 x 2 convolution (the transposed convolution). This transposed convolution halves the number of feature channels, while growing the height and width of the image.

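Next comes a concatenation with the correspondingly cropped feature map from the contracting path, followed by two 3 x 3 convolutions, each followed by a ReLU. In the original paper, cropping is necessary because every unpadded convolution loses border pixels. The `DoubleConv` block provided below uses padded convolutions (`padding=1`), which preserve height and width. As a result, the skip-connection feature maps already match in size whenever the input height and width are divisible by 16 (four 2 x 2 poolings), and no cropping is required.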

**Final Feature Mapping Block:** 

In the final layer, a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes. The channel dimensions from the previous layer correspond to the number of filters used, so when you use 1x1 convolutions, you can transform that dimension by choosing an appropriate number of 1x1 filters. When this idea is applied to the last layer, you can reduce the channel dimensions to have one layer per class.

> Ronneberger, Olaf, et al. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” arXiv.Org, 18 May 2015, https://arxiv.org/abs/1505.04597v1.

The code for the double-convolution block (`DoubleConv`) is provided.


In [None]:
# Seed every random number generator used in this notebook (PyTorch CPU, NumPy,
# PyTorch CUDA, Python's `random`) with 0 so runs are repeatable. Note that
# nondeterministic cuDNN kernels are not disabled here.
for _seed_fn in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed, random.seed):
    _seed_fn(0)

## Implement Model

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by ReLU (optionally preceded by BatchNorm).

    Both convolutions use ``padding=1``, so height and width are preserved and only the
    number of channels changes (``in_channels`` -> ``out_channels``).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        batch_norm: if True, an ``nn.BatchNorm2d`` is inserted between each convolution
            and its ReLU.
    """

    def __init__(self, in_channels, out_channels, batch_norm=False):
        super().__init__()
        layers = []
        # First conv maps in -> out channels, the second keeps out channels.
        for conv_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
        # A single nn.Sequential called ``step`` keeps the state_dict keys
        # (``step.0.weight`` ...) identical to the explicit two-branch layout.
        self.step = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the double convolution to ``x``."""
        return self.step(x)

In [None]:
class UNet(nn.Module):
    """Vanilla U-Net for 2-D semantic segmentation.

    Architecture (``c = init_channels``):
      * encoder: four ``DoubleConv`` stages with c, 2c, 4c, 8c channels, each followed
        by 2x2 max pooling;
      * bottleneck: ``DoubleConv`` with 16c channels;
      * decoder: four stages of bilinear 2x upsampling, concatenation with the matching
        encoder feature map (skip connection) and a ``DoubleConv`` that reduces the
        channels back to 8c, 4c, 2c, c;
      * final 1x1 convolution mapping the c-channel features to ``out_channels`` logits.

    ``DoubleConv`` uses padded 3x3 convolutions, so skip connections need no cropping
    when height and width are divisible by 16. For other sizes the upsampled map is
    resized to the skip connection's size before concatenation.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of output maps (one logit map per class).
        init_channels: channels after the first encoder stage; doubled at every
            downsampling step.
        batch_norm: forwarded to every ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels, init_channels=64, batch_norm=False):
        super().__init__()
        c = init_channels

        # Encoder part
        self.enc1 = DoubleConv(in_channels, c, batch_norm)
        self.enc2 = DoubleConv(c, 2 * c, batch_norm)
        self.enc3 = DoubleConv(2 * c, 4 * c, batch_norm)
        self.enc4 = DoubleConv(4 * c, 8 * c, batch_norm)

        # Bottleneck layer
        self.bottleneck = DoubleConv(8 * c, 16 * c, batch_norm)

        # Decoder part: input channels = upsampled features + skip-connection features
        # (bilinear upsampling does not change the channel count).
        self.dec4 = DoubleConv(16 * c + 8 * c, 8 * c, batch_norm)
        self.dec3 = DoubleConv(8 * c + 4 * c, 4 * c, batch_norm)
        self.dec2 = DoubleConv(4 * c + 2 * c, 2 * c, batch_norm)
        self.dec1 = DoubleConv(2 * c + c, c, batch_norm)

        # Final 1x1 convolution: one output map per class
        self.final = nn.Conv2d(c, out_channels, kernel_size=1)

        # Pooling and upsampling
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def _up_and_concat(self, x, skip):
        """Upsample ``x`` 2x and concatenate it with ``skip`` along the channel axis."""
        x = self.upsample(x)
        if x.shape[-2:] != skip.shape[-2:]:
            # Odd sizes are floored by max pooling; resize so the maps line up.
            x = nn.functional.interpolate(
                x, size=skip.shape[-2:], mode='bilinear', align_corners=False
            )
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Return per-class logits with the same height and width as ``x``."""
        # Encoding path
        e1 = self.enc1(x)
        e2 = self.enc2(self.maxpool(e1))
        e3 = self.enc3(self.maxpool(e2))
        e4 = self.enc4(self.maxpool(e3))

        # Bottleneck
        b = self.bottleneck(self.maxpool(e4))

        # Decoding path with skip connections
        d4 = self.dec4(self._up_and_concat(b, e4))
        d3 = self.dec3(self._up_and_concat(d4, e3))
        d2 = self.dec2(self._up_and_concat(d3, e2))
        d1 = self.dec1(self._up_and_concat(d2, e1))

        return self.final(d1)

## Prepare The Dataset

Now we will run this U-Net on a public dataset, the [Automated Cardiac Diagnosis Challenge (ACDC)](https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html).


The ACDC dataset includes cine MRI scans of the heart in the short-axis view. It provides ground-truth segmentation masks for several cardiac structures: the left ventricle, the right ventricle and the myocardium. The dataset consists of 100 subjects for training and 50 subjects for testing, evenly divided into five groups: four pathological groups and one group of healthy subjects. The primary goal of the challenge is to assess automated methods for cardiac segmentation and diagnosis.

> O. Bernard, A. Lalande, C. Zotti, F. Cervenansky, et al.
"Deep Learning Techniques for Automatic MRI Cardiac Multi-structures Segmentation and
Diagnosis: Is the Problem Solved ?" in IEEE Transactions on Medical Imaging,
vol. 37, no. 11, pp. 2514-2525, Nov. 2018

In [None]:
data_folder = "../data"

# Before running the following cells, create a sub-folder of `data_folder`
# (here `data/ACDC_all`) that contains all 150 patient folders, i.e. copy every
# subject from the ACDC `training` and `testing` folders into this one folder:
# ACDC_all/
# └── patient001/
#     ├── Info.cfg
#     ├── MANDATORY_CITATION.md
#     ├── patient001_4d.nii.gz
#     ├── patient001_frame01_gt.nii.gz
#     ├── patient001_frame01.nii.gz
#     ├── patient001_frame12_gt.nii.gz
#     └── patient001_frame12.nii.gz
# The preview cell right below joins paths onto this folder, so it must be a real
# path (None would make os.path.join raise a TypeError).
ACDC_raw_folder = os.path.join(data_folder, "ACDC_all")

We can make a plot to get a glimpse of the data.

In [None]:
# Preview frame 01 of patient002: the image and its ground-truth label volume.
image1_data, mask1_data = (
    nib.load(os.path.join(ACDC_raw_folder, rel_path)).get_fdata()
    for rel_path in (
        "patient002/patient002_frame01.nii.gz",
        "patient002/patient002_frame01_gt.nii.gz",
    )
)

# First slice along the third axis of each volume.
image1, mask1 = image1_data[:, :, 0], mask1_data[:, :, 0]
# Background (label 0) becomes NaN; matplotlib draws NaN with the colormap's
# "bad" colour (transparent by default), so only labelled pixels are overlaid.
mask1 = np.where(mask1 == 0, np.nan, mask1)

plt.imshow(image1, cmap='gray')
plt.imshow(mask1, cmap='jet', alpha=0.5)

The red part is the left ventricle, the blue part is the right ventricle, and the green part is the myocardium.

In [None]:
# Input: the merged ACDC folder (training + testing patients).
# Output folder: the .npy files that ACDC_Dataset reads (via data_dir below).
ACDC_raw_folder = os.path.join(data_folder, "ACDC_all")
np_data_folder = os.path.join(data_folder, "np_data")

# Preprocess_ACDC returns the bookkeeping dict later passed to ACDC_Dataset as
# `len_dict` (presumably the number of samples per split; see utils.Task1_utils).
len_dict = Preprocess_ACDC(ACDC_raw_folder)
print(len_dict)

The images and their corresponding masks are now stored in `data/images` and `data/masks`. The `.npy` files, which are the ones we will use from now on, are stored in `data/np_data/`.

In [None]:
# Load the experiment configuration and expose the sections used in later cells.
config = load_config("../config/Task1.1_config.yaml")

NUM_CLASSES, INPUT_SIZE = (
    config["dataset"][key] for key in ("number_classes", "input_size")
)
TRAIN_PARAMS, MODEL_PARAMS = config["train"], config["model"]

In [None]:
# Build the train / validation / test splits from the preprocessed .npy data.
# The three datasets share every option except `mode`.
_dataset_kwargs = dict(
    data_dir=np_data_folder,
    len_dict=len_dict,
    one_hot=True,
    num_classes=NUM_CLASSES,
)
train_dataset = ACDC_Dataset(mode="train", **_dataset_kwargs)
val_dataset = ACDC_Dataset(mode="val", **_dataset_kwargs)
test_dataset = ACDC_Dataset(mode="test", **_dataset_kwargs)

In [None]:
# One DataLoader per split; the keyword arguments for each split come from
# config["data_loader"][<split>].
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split_dataset, **config['data_loader'][split_name])
    for split_dataset, split_name in (
        (train_dataset, 'train'),
        (val_dataset, 'val'),
        (test_dataset, 'test'),
    )
)

We need some metrics to evaluate the performance of our model. Please implement the Dice coefficient and Jaccard index (IoU) using `torchmetrics.MetricCollection`.

In [None]:
# Segmentation metrics: Dice coefficient and Jaccard index (IoU).
# Both are fed integer class-index maps (e.g. `logits.argmax(dim=1)`).
# Label 0 is the background; it is excluded via ignore_index so the large
# background area does not inflate the scores.
metrics = torchmetrics.MetricCollection(
    [
        torchmetrics.Dice(num_classes=NUM_CLASSES, ignore_index=0),
        torchmetrics.JaccardIndex(
            task="multiclass", num_classes=NUM_CLASSES, ignore_index=0
        ),
    ],
    prefix='metrics/'
)

In [None]:
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(f"Torch device: {device}")

# Independent copies of the metric collection, one per split, each with its
# own key prefix and moved to the compute device.
train_metrics, val_metrics, test_metrics = (
    metrics.clone(prefix=split_prefix).to(device)
    for split_prefix in ('train_metrics/', 'val_metrics/', 'test_metrics/')
)

In [None]:
def validate(model, device, val_dataloader, criterion):
    """
    Evaluate ``model`` on ``val_dataloader``. Used by ``train`` after every epoch.

    Args:
        model: segmentation network; its output is assumed to be per-class logits of
            shape (N, C, H, W).
        device: device the batches are moved to (should match the model's device).
        val_dataloader: iterable of dicts with ``'image'`` and ``'mask'`` tensors.
        criterion: loss ``criterion(preds, masks)``; assumed to return the mean loss
            over the batch (as the notebook's CrossEntropyLoss does).

    Returns:
        (val_evaluator, val_loss): a fresh clone of the global ``val_metrics`` updated
        with every batch, and the per-sample average loss (NaN for an empty loader).
    """

    def _score(results, name):
        # Metric keys depend on the collection prefix and metric class names,
        # so look them up by case-insensitive substring.
        return next(
            (float(v) for k, v in results.items() if name in k.lower()), float("nan")
        )

    model.eval()
    with torch.no_grad():

        val_evaluator = val_metrics.clone().to(device)
        loss_sum = 0.0
        cnt = 0

        for batch_data in val_dataloader:
            images = batch_data['image'].to(device)
            masks = batch_data['mask'].to(device)
            preds = model(images)

            batch_size = images.shape[0]
            cnt += batch_size

            loss = criterion(preds, masks)
            # criterion returns a batch mean: weight it by the batch size so that the
            # epoch loss is a true per-sample mean (also when the last batch is smaller).
            loss_sum += loss.item() * batch_size

            # Metrics need class-index maps: argmax the logits over the class axis, and
            # the masks too when they are one-hot encoded (same rank as the logits).
            pred_labels = preds.argmax(dim=1)
            target_labels = masks.argmax(dim=1) if masks.dim() == preds.dim() else masks.long()
            val_evaluator.update(pred_labels, target_labels)

        val_loss = loss_sum / cnt if cnt else float("nan")

        if cnt:
            results = val_evaluator.compute()
            _loss = f"curr-loss:{val_loss:0.5f}"
            _dice = f"dice:{_score(results, 'dice'):0.5f}"
            _iou = f"iou:{_score(results, 'jaccard'):0.5f}"

            print(f"Validation) -> {_loss}, {_dice}, {_iou}")

    return val_evaluator, val_loss

In [None]:
def train(
    model,
    device,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    scheduler
):
    """
    Train ``model`` for ``TRAIN_PARAMS["epochs"]`` epochs, validating after each epoch.

    The weights with the lowest validation loss are snapshotted into a separate model.
    After training, ``MODEL_PARAMS["save_dir"]`` is deleted and re-created, and
    ``Train_info.json``, ``last_model_state_dict.pt`` and ``best_model_state_dict.pt``
    are written into it.

    Args:
        model: segmentation network; its output is assumed to be per-class logits of
            shape (N, C, H, W).
        device: device used for training.
        train_dataloader, val_dataloader: iterables of dicts with ``'image'`` and
            ``'mask'`` tensors.
        criterion: loss ``criterion(preds, masks)``; assumed to return the batch mean.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped once per epoch with the validation loss.

    Returns:
        (best_model, model, train_info): an independent copy holding the
        best-validation weights, the trained model in its final state, and the
        dict written to ``Train_info.json``. If the validation loss never improves
        (e.g. it is NaN), ``best_model`` is a copy of the final model.
    """
    import copy  # only needed to snapshot the best model

    def _score(results, name):
        # Metric keys depend on the collection prefix and metric class names,
        # so look them up by case-insensitive substring.
        return next(
            (float(v) for k, v in results.items() if name in k.lower()), float("nan")
        )

    EPOCHS = TRAIN_PARAMS["epochs"]
    torch.cuda.empty_cache()
    model = model.to(device)

    train_evaluator = train_metrics.clone().to(device)
    train_iterator = tqdm(total=EPOCHS)

    epochs_info = []
    best_model = None
    best_result = {}
    best_val_loss = np.inf
    for epoch in range(EPOCHS):
        model.train()

        train_evaluator.reset()
        loss_sum = 0.0
        cnt = 0
        for batch_data in train_dataloader:
            images = batch_data["image"].to(device)
            masks = batch_data["mask"].to(device)

            optimizer.zero_grad()  # clear existing gradients
            preds = model(images)
            loss = criterion(preds, masks)
            loss.backward()
            optimizer.step()

            batch_size = images.shape[0]
            cnt += batch_size
            # Weight the batch-mean loss by batch size for a per-sample epoch mean.
            loss_sum += loss.item() * batch_size

            # Same conversion as in validate(): class-index maps for the metrics.
            with torch.no_grad():
                pred_labels = preds.argmax(dim=1)
                target_labels = masks.argmax(dim=1) if masks.dim() == preds.dim() else masks.long()
                train_evaluator.update(pred_labels, target_labels)

        train_loss = loss_sum / cnt if cnt else float("nan")
        train_iterator.update(1)

        if cnt:
            results = train_evaluator.compute()
            _loss = f"curr-loss:{train_loss:0.5f}"
            _dice = f"dice:{_score(results, 'dice'):0.5f}"
            _iou = f"iou:{_score(results, 'jaccard'):0.5f}"
            print(
                f"\nTraining) ep:{epoch+1:03d}/{EPOCHS} -> {_loss}, {_dice}, {_iou}"
            )

        val_evaluator, val_loss = validate(model, device, val_dataloader, criterion)

        epoch_info = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_metrics": make_serializeable_metrics(train_evaluator.compute()),
            "val_metrics": make_serializeable_metrics(val_evaluator.compute()),
        }
        if val_loss < best_val_loss:
            # Deep-copy the model: assigning ``best_model = model`` would only alias the
            # live model, and the "best" checkpoint would silently be the last one.
            best_model = copy.deepcopy(model)
            best_val_loss = val_loss
            best_result = dict(epoch_info)

        epochs_info.append(epoch_info)
        train_evaluator.reset()

        scheduler.step(val_loss)

    train_iterator.close()
    if best_model is None:
        best_model = copy.deepcopy(model)

    # Save the info of each epoch and the best result.
    # NOTE: any previous content of save_dir is deleted.
    if os.path.exists(MODEL_PARAMS["save_dir"]):
        shutil.rmtree(MODEL_PARAMS["save_dir"], ignore_errors=True)
    os.makedirs(MODEL_PARAMS["save_dir"], exist_ok=True)
    train_info = {
        "epochs_info": epochs_info,
        "best_result": best_result,
    }
    file_name = "Train_info.json"
    file_path = os.path.join(MODEL_PARAMS["save_dir"], file_name)
    with open(file_path, "w") as write_file:
        json.dump(train_info, write_file, indent=4)

    # save last model's parameters
    file_name = "last_model_state_dict.pt"
    file_path = os.path.join(MODEL_PARAMS["save_dir"], file_name)
    torch.save(model.state_dict(), file_path)

    # save best model's parameters
    file_name = "best_model_state_dict.pt"
    file_path = os.path.join(MODEL_PARAMS["save_dir"], file_name)
    torch.save(best_model.state_dict(), file_path)

    return best_model, model, train_info

In [None]:
def test(model, device, test_dataloader):
    """
    Run ``model`` over ``test_dataloader`` and accumulate segmentation metrics.

    Args:
        model: segmentation network; its output is assumed to be per-class logits of
            shape (N, C, H, W).
        device: device the batches are moved to.
        test_dataloader: iterable of dicts with ``'image'`` and ``'mask'`` tensors.

    Returns:
        A fresh clone of the global ``test_metrics`` updated with every batch; call
        ``.compute()`` on it to obtain the scores.
    """
    model.eval()
    with torch.no_grad():
        test_evaluator = test_metrics.clone().to(device)
        for batch_data in tqdm(test_dataloader):
            images = batch_data['image'].to(device)
            masks = batch_data['mask'].to(device)

            preds = model(images)

            # Metrics need class-index maps: argmax the logits over the class axis, and
            # the masks too when they are one-hot encoded (same rank as the logits).
            pred_labels = preds.argmax(dim=1)
            target_labels = masks.argmax(dim=1) if masks.dim() == preds.dim() else masks.long()
            test_evaluator.update(pred_labels, target_labels)

    return test_evaluator

In [None]:
# Build the U-Net from the config's model parameters and report its trainable size.
model = UNet(**MODEL_PARAMS["params"])
torch.cuda.empty_cache()
model = model.to(device)
_trainable = [param for param in model.parameters() if param.requires_grad]
print("Number of parameters:", sum(map(torch.numel, _trainable)))

Before training the U-Net model, we define a few essential components: the optimizer, the learning-rate scheduler and the loss function.

In [None]:
# Optimizer: its class is looked up by name (e.g. "Adam") and built with the
# learning rate from the config.
_optimizer_name = TRAIN_PARAMS["optimizer"]["name"]
# The imports at the top of the notebook bring no optimizer class into globals(),
# so a bare globals()[name] lookup raises KeyError. Resolve it from torch.optim
# first and keep globals() as a fallback for optimizers defined in the notebook.
_optimizer_cls = getattr(torch.optim, _optimizer_name, None) or globals()[_optimizer_name]
optimizer = _optimizer_cls(
    model.parameters(), lr=TRAIN_PARAMS["optimizer"]["lr"]
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, "min", **TRAIN_PARAMS["scheduler"]
)

# During the training, we use cross entropy loss as the criterion.
# ignore_index is deliberately not set, otherwise background would not be learned.
# The loss module is stateless, so build it once instead of on every call.
_cross_entropy_loss = CrossEntropyLoss()


def criterion(preds, masks):
    """Mean cross-entropy between logits ``preds`` and target ``masks``.

    ``masks`` may hold class indices or per-class probabilities (e.g. one-hot),
    both of which ``CrossEntropyLoss`` accepts.
    """
    return _cross_entropy_loss(preds, masks)

In [None]:
# Run the training loop; `train` returns (best model, last model, training history)
# and writes the checkpoints into MODEL_PARAMS['save_dir'].
best_model, last_model, train_info = train(
    model=model,
    device=device,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)
print(f"Best and last models are saved to path {MODEL_PARAMS['save_dir']}")

In [None]:
# Re-create the architecture and load the best checkpoint written by `train`.
best_model = UNet(**MODEL_PARAMS["params"])
torch.cuda.empty_cache()
best_model = best_model.to(device)

file_name = "best_model_state_dict.pt"
file_path = os.path.join(MODEL_PARAMS["save_dir"], file_name)

# map_location=device remaps the stored tensors onto the device in use. This single
# call covers CPU-only machines as well as checkpoints saved on a different GPU.
best_model.load_state_dict(
    torch.load(file_path, map_location=device, weights_only=True)
)
# Inference mode (batch-norm uses running statistics) for evaluation and plotting.
best_model.eval()

In [None]:
# Evaluate the loaded best model on the test split and print the metric values.
test_evaluator = test(best_model, device, test_dataloader)
_test_scores = test_evaluator.compute()
print(_test_scores)

Finally, we make some visualization of the model's performance as well as some sample segmentation results.

In [None]:
train_info_path = f"{MODEL_PARAMS['save_dir']}/Train_info.json"
with open(train_info_path, "r") as f:
    train_info = json.load(f)

epochs_info = train_info["epochs_info"]
train_losses = [d["train_loss"] for d in epochs_info]
val_losses = [d["val_loss"] for d in epochs_info]


def _metric_history(epochs, split_key, name):
    """Return one value per epoch for the metric whose key contains ``name``.

    ``split_key`` is ``"train_metrics"`` or ``"val_metrics"``. Each entry is assumed
    to be a dict of metric-name -> number (as produced by make_serializeable_metrics;
    verify). Keys are matched case-insensitively by substring, because their exact
    spelling depends on the metric collection's prefix and metric class names.
    NaN is used for an epoch where no key matches.
    """
    history = []
    for info in epochs:
        scores = info[split_key]
        history.append(
            next((v for k, v in scores.items() if name in k.lower()), float("nan"))
        )
    return history


# One value per epoch for every curve.
train_dice = _metric_history(epochs_info, "train_metrics", "dice")
val_dice = _metric_history(epochs_info, "val_metrics", "dice")
train_jaccard = _metric_history(epochs_info, "train_metrics", "jaccard")
val_jaccard = _metric_history(epochs_info, "val_metrics", "jaccard")


_, axs = plt.subplots(1, 3, figsize=[15, 5])

axs[0].set_title("Loss")
axs[0].plot(train_losses, "r-", label="train loss")
axs[0].plot(val_losses, "b-", label="validation loss")
axs[0].legend()

axs[1].set_title("Dice score")
axs[1].plot(train_dice, "r-", label="train dice")
axs[1].plot(val_dice, "b-", label="validation dice")
axs[1].legend()

axs[2].set_title("Jaccard Similarity")
axs[2].plot(train_jaccard, "r-", label="train JaccardIndex")
axs[2].plot(val_jaccard, "b-", label="validation JaccardIndex")
axs[2].legend()

plt.savefig(f"{MODEL_PARAMS['save_dir']}/metrics.png")

In [None]:
save_imgs_dir = f"{MODEL_PARAMS['save_dir']}/visualized"

# Start from an empty output folder; makedirs also creates missing parent folders.
if os.path.isdir(save_imgs_dir):
    shutil.rmtree(save_imgs_dir)
os.makedirs(save_imgs_dir)

title_dict = {1: "RV", 2: "Myo", 3: "LV"}  # Right ventricle, myocardium, left ventricle

# Inference mode, regardless of which cells ran before this one.
best_model.eval()

# Visualize one batch of test data: for every sample and every foreground class that
# appears in both the ground truth and the prediction, save a comparison image.
with torch.no_grad():
    for batch in test_dataloader:
        ids = batch['id']
        images = batch['image'].to(device)
        masks = batch['mask']
        preds = best_model(images)

        images_np = images.cpu().numpy()
        masks_np = masks.cpu().numpy()
        preds_np = torch.argmax(preds, 1).cpu().numpy()
        # One-hot masks carry an extra class axis compared with the argmax'ed
        # predictions; collapse it so `== value` selects a single structure.
        if masks_np.ndim == preds_np.ndim + 1:
            masks_np = masks_np.argmax(axis=1)

        for idx in range(len(masks_np)):
            # Label sets computed once per sample instead of once per class.
            mask_values = set(np.unique(masks_np[idx]).tolist())
            pred_values = set(np.unique(preds_np[idx]).tolist())
            for value in range(1, NUM_CLASSES):
                if value not in mask_values or value not in pred_values:
                    continue

                image = np.moveaxis(images_np[idx, :3], 0, -1) * 255.
                # Clip before the uint8 cast: out-of-range floats have no well-defined uint8 value.
                image = np.ascontiguousarray(np.clip(image, 0, 255), dtype=np.uint8)
                mask = np.where(masks_np[idx] == value, 255, 0)
                mask = np.ascontiguousarray(mask, dtype=np.uint8)
                pred = np.where(preds_np[idx] == value, 255, 0)
                pred = np.ascontiguousarray(pred, dtype=np.uint8)

                plot = skin_plot(image, mask, pred)

                file_id = ids[idx]
                plt.clf()
                plt.title(f"Comparison of {title_dict[value]} for sample {file_id}")
                plt.imshow(plot)
                # No plt.legend(): the figure has no labelled artists, so it would only warn.
                plt.savefig(f"{save_imgs_dir}/Sample_{file_id}_img_{title_dict[value]}.png")
                plt.close()
        break

In the visualized plots, the red contours represent the predictions and the blue contours represent the ground truth.