# Project MTI865 - Heart segmentation using UNet 

---

# Model training - Dice loss and CE Loss 

$$
\mathcal{L} = w_{CE} \mathcal{L}_{CE} + w_{DSC} \mathcal{L}_{DSC}
$$

In [None]:
# adding .. to path 
import sys 
sys.path.append('..')

In [4]:
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2
from progressBar import printProgressBar

import medicalDataLoader
import argparse
import utils

from UNet_Base import *
import random
import torch
import pdb
import matplotlib.pyplot as plt
import numpy as np
import os


In [5]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Parameters 
batch_size = 8
batch_size_val = 4
lr =  0.003   # Learning Rate
total_epochs = 120  # Number of epochs
weight_decay = 1e-5
ce_loss_weight = 0.7 # Cross Entropy Loss Weight proportion
dice_loss_weight = 0.3 # Dice Loss Weight proportion 
param_dict = {
    "batch_size": batch_size,
    "lr": lr,
    "total_epochs": total_epochs,
    "weight_decay": weight_decay,
    "ce_loss_weight": ce_loss_weight,
    "dice_loss_weight": dice_loss_weight
}

modelName = f"DSC-CE-UNet-{lr}-{total_epochs}-{ce_loss_weight}-{dice_loss_weight}"
model_dir = f"models/{modelName}"
# write params in a file 

os.makedirs(model_dir, exist_ok=True)
with open(f"{model_dir}/params.txt", 'w') as f:
    print(param_dict, file=f)

print(f"Parameters saved to {model_dir}/params.txt")


In [7]:
# Define image and mask transformations
transform = v2.Compose([
    v2.ToTensor(),
    v2.Normalize(mean=[0.137], std=[0.1733]) # Normalisation values for the training set (mean and std) 
])

mask_transform = v2.Compose([
    v2.ToTensor(),
])

In [8]:
# Define dataloaders
root_dir = './data/'
print(' Dataset: {} '.format(root_dir))

train_set_full = medicalDataLoader.MedicalImageDataset('train',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    augment=True,
                                                    equalize=False)

train_loader_full = DataLoader(train_set_full,
                            batch_size=batch_size,
                            worker_init_fn=np.random.seed(0),
                            num_workers=0,
                            shuffle=True)


val_set = medicalDataLoader.MedicalImageDataset('val',
                                                root_dir,
                                                transform=transform,
                                                mask_transform=mask_transform,
                                                equalize=False)

val_loader = DataLoader(val_set,
                        batch_size=batch_size_val,
                        worker_init_fn=np.random.seed(0),
                        num_workers=0,
                        shuffle=False)

 Dataset: ./data/ 
Found 204 items in train
First item:  ('./data/train\\Img\\patient006_01_1.png', './data/train\\GT\\patient006_01_1.png')
Found 74 items in val
First item:  ('./data/val\\Img\\patient001_01_1.png', './data/val\\GT\\patient001_01_1.png')


In [9]:
def runTraining(writer: SummaryWriter):
    print("-" * 40)
    print("~~~~~~~~  Starting the training... ~~~~~~")
    print("-" * 40)

    num_classes = 4

    # Set device depending on the availability of GPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    # elif torch.mps.is_available():  # Apple M-series of chips
    #     device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
    print(" Model Name: {}".format(modelName))

    ## CREATION OF YOUR MODEL
    net = UNet(num_classes).to(device)

    print(
        "Total params: {0:,}".format(
            sum(p.numel() for p in net.parameters() if p.requires_grad)
        )
    )

    # DEFINE YOUR OUTPUT COMPONENTS (e.g., SOFTMAX, LOSS FUNCTION, ETC)
    CE_loss = torch.nn.CrossEntropyLoss().to(device)



    ## PUT EVERYTHING IN GPU RESOURCES
    if torch.cuda.is_available():
        net.cuda()
        CE_loss.cuda()
        

    ## DEFINE YOUR OPTIMIZER
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

    ### To save statistics ####
    train_losses = []
    train_dc_losses = []
    val_losses = []
    # val_dc_losses = []

    best_loss_val = 1000

    directory = "Results/Statistics/" + modelName

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    if os.path.exists(directory) == False:
        os.makedirs(directory)

    ## START THE TRAINING

    ## FOR EACH EPOCH
    for epoch in range(total_epochs):
        net.train()

        num_batches = len(train_loader_full)
        print("Number of batches: ", num_batches)

        running_train_loss = 0
        running_dice_loss = 0

        # Training loop
        for idx, data in enumerate(train_loader_full):
            ### Set to zero all the gradients
            net.zero_grad()
            optimizer.zero_grad()

            ## GET IMAGES, LABELS and IMG NAMES
            images, labels, img_names = data

            ### From numpy to torch variables
            labels = utils.to_var(labels).to(device)
            images = utils.to_var(images).to(device)

            # Forward pass
            net_predictions = net(images) # Predictions have shape [batch_size, num_classes, height, width]

            # Get the segmentation classes
            segmentation_classes = utils.getTargetSegmentation(labels)
            # Modify segmentation classes to be one-hot encoded (shape [batch_size, num_classes, height, width])
            dice_target = F.one_hot(segmentation_classes, num_classes = num_classes).permute(0,3,1,2).contiguous()

            # Compute the loss
            loss_ce = ce_loss_weight * CE_loss(net_predictions, segmentation_classes)
            loss_dice = dice_loss_weight * DiceLoss()(net_predictions, dice_target)
            loss = loss_ce + loss_dice

            running_train_loss += loss.item()
            # dice_loss = dice_coefficient(net_predictions, labels)
            # dice_loss = utils.compute_dsc(net_predictions, labels)
            # running_dice_loss += dice_loss

            # Backprop
            loss.backward()
            optimizer.step()

            # Add the loss to the tensorboard every 5 batches
            if idx % 10 == 0:
                writer.add_scalar(
                    "Loss/train", running_train_loss / (idx + 1), epoch * len(train_loader_full) + idx
                )

            if idx % 100 == 0:
                # Also add visualizations of the images
                probs = torch.softmax(net_predictions, dim=1)
                y_pred = torch.argmax(probs, dim=1)
                writer.add_figure(f'predictions vs. actuals for {modelName}',
                            utils.plot_net_predictions(images, labels, y_pred, batch_size),
                            global_step=epoch * len(train_loader_full) + idx)

            # THIS IS JUST TO VISUALIZE THE TRAINING
            printProgressBar(
                idx + 1,
                num_batches,
                prefix="[Training] Epoch: {} ".format(epoch),
                length=15,
                suffix=" Loss: {:.4f}, ".format(running_train_loss / (idx + 1)),
            )

        train_loss = running_train_loss / num_batches
        train_losses.append(train_loss)

        train_dc_loss = running_dice_loss / num_batches
        train_dc_losses.append(train_dc_loss)

        net.eval()
        val_running_loss = 0
        val_running_dc = 0

        # Validation loop
        with torch.no_grad():
            for idx, data in enumerate(val_loader):
                images, labels, img_names = data

                labels = utils.to_var(labels).to(device)
                images = utils.to_var(images).to(device)

                net_predictions = net(images)

                segmentation_classes = utils.getTargetSegmentation(labels)
                dice_target = F.one_hot(segmentation_classes, num_classes = num_classes).permute(0,3,1,2).contiguous()

                loss_ce = ce_loss_weight * CE_loss(net_predictions, segmentation_classes)
                loss_dice = dice_loss_weight * DiceLoss()(net_predictions, dice_target)
                loss = loss_ce + loss_dice

                val_running_loss += loss.item()
                # dice_loss = dice_coefficient(net_predictions, labels)
                # dice_loss = utils.compute_dsc(net_predictions, labels)
                # val_running_dc += dice_loss

                if idx % 10 == 0:
                    writer.add_scalar(
                        "Loss/val",
                        val_running_loss / (idx + 1),
                        epoch * len(val_loader) + idx,
                    )

                printProgressBar(
                    idx + 1,
                    len(val_loader),
                    prefix="[Validation] Epoch: {} ".format(epoch),
                    length=15,
                    suffix=" Loss: {:.4f}, ".format(val_running_loss / (idx + 1)),
                )

        val_loss = val_running_loss / len(val_loader)
        val_losses.append(val_loss)
        # dc_loss = val_running_dc / len(val_loader)
        # val_dc_losses.append(dc_loss)

        # Check if model performed best and save it if true
        if val_loss < best_loss_val:
            best_loss_val = val_loss
            if not os.path.exists("./models/" + modelName):
                os.makedirs("./models/" + modelName)
            torch.save(
                net.state_dict(), "./models/" + modelName + "/" + str(epoch) + "_Epoch"
            )
        
        printProgressBar(
            num_batches,
            num_batches,
            done="[Epoch: {}, TrainLoss: {:.4f}, TrainDice: {:.4f}, ValLoss: {:.4f}".format(
                epoch, train_loss, train_dc_loss, val_loss
            ),
        )

        # Check if loss has not decreased for 10 epochs
        # if epoch > 10 and val_losses[-1] > val_losses[-10]:
        #     print("Stopping early as validation loss has not decreased for 10 epochs")
        #     break

        np.save(os.path.join(directory, "Losses.npy"), train_losses)
    writer.flush()  # Flush the writer to ensure that all the data is written to disk

In [10]:
# Set up Tensorboard writer
writer = SummaryWriter()
runTraining(writer)
writer.close()

----------------------------------------
~~~~~~~~  Starting the training... ~~~~~~
----------------------------------------
Using device: cpu
~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~
 Model Name: DSC-CE-UNet-0.003-120-0.7-0.3
Total params: 60,664
~~~~~~~~~~~ Starting the training ~~~~~~~~~~
Number of batches:  26
[Training] Epoch: 0 [DONE]                                 
[Validation] Epoch: 0 [DONE]                                 
[Epoch: 0, TrainLoss: 1.0556, TrainDice: 0.0000, ValLoss: 0.8890                                             
Number of batches:  26
[Training] Epoch: 1 [DONE]                                 
[Validation] Epoch: 1 [DONE]                                 
[Epoch: 1, TrainLoss: 0.7220, TrainDice: 0.0000, ValLoss: 0.5998                                             
Number of batches:  26
[Training] Epoch: 2 [DONE]                                 
[Validation] Epoch: 2 [DONE]                                 
[Epoch: 2, TrainLoss: 0.4782, TrainDice: 0.0000