# Project MTI865 - Heart segmentation using UNet 

---

# Model training 

$$
\mathcal{L} = \mathcal{L}_{CE} 
$$

## Import libraries

In [2]:
# adding .. to path 
import sys 
sys.path.append('..')

In [3]:
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2
from progressBar import printProgressBar

import medicalDataLoader
import argparse
import utils

from UNet_Base import *
import random
import torch
import pdb
import matplotlib.pyplot as plt
import numpy as np
import os

In [3]:
import warnings
warnings.filterwarnings("ignore") 

## Loading data 

In [4]:
batch_size = 4
batch_size_val = 4
batch_size_unlabel = 8

In [5]:
# Define image and mask transformations
transform = v2.Compose([
    v2.ToTensor()
])

mask_transform = v2.Compose([
    v2.ToTensor()
])

In [6]:
def collate_fn(batch):
    imgs = []
    masks = []
    img_paths = []

    for item in batch:
        img, mask, img_path = item[0], item[1], item[2]
        imgs.append(img)
        img_paths.append(img_path)
        
        # Si le masque est None, ajouter un tenseur de zéros correspondant à sa taille
        if mask is not None:
            masks.append(mask)
        else:
            masks.append(torch.zeros_like(img[0, :, :]))  # Même taille que le canal de l'image (assumant CxHxW)

    # Stack les images et les masques
    imgs_tensor = torch.stack(imgs)  # Tensor de forme (B, C, H, W)
    masks_tensor = torch.stack(masks)  # Tensor de forme (B, H, W)

    return imgs_tensor, masks_tensor, img_paths



In [7]:
# Define dataloaders
root_dir = './data/'
print(' Dataset: {} '.format(root_dir))

supervised_set = medicalDataLoader.MedicalImageDataset('train',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    augment=True,
                                                    equalize=False)


supervised_loader = DataLoader(
    supervised_set,
    batch_size=batch_size,
    worker_init_fn=np.random.seed(0),
    num_workers=0,
    shuffle=True,
    collate_fn=collate_fn)


val_set = medicalDataLoader.MedicalImageDataset('val',
                                                root_dir,
                                                transform=transform,
                                                mask_transform=mask_transform,
                                                equalize=False)

val_loader = DataLoader(val_set,
                        batch_size=batch_size_val,
                        worker_init_fn=np.random.seed(0),
                        num_workers=0,
                        shuffle=False)

unsupervised_set = medicalDataLoader.MedicalImageDataset('train-unlabelled',
                                                            root_dir,
                                                            transform=transform,
                                                            mask_transform=mask_transform,
                                                            augment=False,
                                                            equalize=False)
# print(train_unlabelled_set.imgs)
# train_unlabelled_set = [(img) for img, mask in train_unlabelled_set]
unsupervised_loader = DataLoader(unsupervised_set,
                                    batch_size=batch_size_unlabel,
                                    worker_init_fn=np.random.seed(0),
                                    num_workers=0,
                                    shuffle=False,
                                    collate_fn=collate_fn)



print('Train set: ', len(supervised_set))
print('Validation set: ', len(val_set))

n_train_label = len(supervised_set)
n_train_unlabel = len(unsupervised_set)

# shape of the image a  nd mask
img, mask, _ = supervised_set[0]
print('Image shape: ', img.shape)
print('Mask shape: ', mask.shape)
print('Number of batches: ', len(supervised_loader))

img, mask, _ = val_set[0]
print('Image shape: ', img.shape)
print('Mask shape: ', mask.shape)
print('Number of batches: ', len(val_loader))

img, _, __ = unsupervised_set[0]
print('Image shape: ', img.shape)
print('Mask shape: ', mask.shape)
print('Number of batches: ', len(unsupervised_loader))


print('First of the supervised set')
img, mask, path_tuple = supervised_set[0]
print(img)
print(mask)
print(path_tuple)

print('First of the unsupervised set')
img, mask, path_tuple = unsupervised_set[0]
print(img)
print(mask)
print(path_tuple)




 Dataset: ./data/ 
Found 204 items in train
First item:  ('./data/train\\Img\\patient006_01_1.png', './data/train\\GT\\patient006_01_1.png')
Found 74 items in val
First item:  ('./data/val\\Img\\patient001_01_1.png', './data/val\\GT\\patient001_01_1.png')
Found 1004 items in train-unlabelled
First item:  ('./data/train\\Img-Unlabeled\\patient007_01_1.png', None)
Train set:  204
Validation set:  74
Image shape:  torch.Size([1, 256, 256])
Mask shape:  torch.Size([1, 256, 256])
Number of batches:  51
Image shape:  torch.Size([1, 256, 256])
Mask shape:  torch.Size([1, 256, 256])
Number of batches:  19
Image shape:  torch.Size([1, 256, 256])
Mask shape:  torch.Size([1, 256, 256])
Number of batches:  126
First of the supervised set
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
te

## Modèle basique : Entraînement avec les GT uniquement 

### Paramètres de l'entraînement

In [11]:
# Parameters 
lr =  0.001    # Learning Rate
total_epochs = 150  # Number of epochs
weight_decay = 1e-5  # Weight decay
modelName = f"Default_UNet-{total_epochs}epochs-{lr}lr{weight_decay}wd" 

# save parameters
param_dict = {
    'model':'Default_UNet',
    'lr': lr,
    'total_epochs': total_epochs,
    'weight_decay': weight_decay,
    'modelName': modelName, 
    'batch_size': batch_size,
    'batch_size_val': batch_size_val,
    'batch_size_unlabel': batch_size_unlabel
}
model_dir = f"models/{modelName}"
# write params in a file 

os.makedirs(model_dir, exist_ok=True)
with open(f"{model_dir}/params.txt", 'w') as f:
    print(param_dict, file=f)

print(f"Parameters saved to {model_dir}/params.txt")

FileNotFoundError: [Errno 2] No such file or directory: './models/Default_UNet-150epochs-0.001lr1e-05wd/params.txt'

### Entraînement du modèle 

In [28]:
def runTraining(writer: SummaryWriter):
    print("-" * 40)
    print("~~~~~~~~  Starting the training... ~~~~~~")
    print("-" * 40)

    num_classes = 4

    # Set device depending on the availability of GPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    # elif torch.mps.is_available():  # Apple M-series of chips
    #     device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
    
    print(" Model Name: {}".format(modelName))

    # Create the model
    net = UNet(num_classes).to(device)

    print(
        "Total params: {0:,}".format(
            sum(p.numel() for p in net.parameters() if p.requires_grad)
        )
    )

    # Define the loss function
    softMax = torch.nn.Softmax(dim=1)
    CE_loss = torch.nn.CrossEntropyLoss()

    ## PUT EVERYTHING IN GPU RESOURCES
    # if torch.cuda.is_available():
    #     net.cuda()
    #     softMax.cuda()
    #     CE_loss.cuda()

    ## OPTIMIZER
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

    ### To save statistics ####
    train_losses = []
    train_dc_losses = []
    val_losses = []
    val_dc_losses = []

    best_loss_val = 1000

    directory = "Results/Statistics/" + modelName

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    if os.path.exists(directory) == False:
        os.makedirs(directory)

    ## START THE TRAINING

    ## FOR EACH EPOCH
    for epoch in range(total_epochs):
        net.train()

        num_batches = len(supervised_loader)
        print("Number of batches: ", num_batches)

        running_train_loss = 0
        running_dice_loss = 0

        # Training loop
        for idx, data in enumerate(supervised_loader):
            ### Set to zero all the gradients
            net.zero_grad()
            optimizer.zero_grad()

            ## GET IMAGES, LABELS and IMG NAMES
            images, labels, img_names = data
            # print("Type of labels before to_var:", type(labels))
            ### From numpy to torch variables
            labels = utils.to_var(labels).to(device)
            images = utils.to_var(images).to(device)

            # Forward pass
            net_predictions = net(images)

            # Get the segmentation classes
            segmentation_classes = utils.getTargetSegmentation(labels)

            # Compute the loss
            loss = CE_loss(net_predictions, segmentation_classes)
            running_train_loss += loss.item()
            # dice_loss = dice_coefficient(net_predictions, labels)
            dice_loss = utils.compute_dsc(net_predictions, labels)
            running_dice_loss += dice_loss

            # Backprop
            loss.backward()
            optimizer.step()

            # Add the loss to the tensorboard every 5 batches
            if idx % 10 == 0:
                writer.add_scalar(
                    "Loss/train", running_train_loss / (idx + 1), epoch * len(supervised_loader) + idx
                )
                writer.add_scalar(
                    "Dice/train", running_dice_loss / (idx + 1), epoch * len(supervised_loader) + idx
                )

            if idx % 100 == 0:
                # Also add visualizations of the images
                probs = torch.softmax(net_predictions, dim=1)
                y_pred = torch.argmax(probs, dim=1)
                writer.add_figure(f'predictions vs. actuals {modelName}',
                            utils.plot_net_predictions(images, labels, y_pred, batch_size),
                            global_step=epoch * len(supervised_loader) + idx)

            # THIS IS JUST TO VISUALIZE THE TRAINING
            printProgressBar(
                idx + 1,
                num_batches,
                prefix="[Training] Epoch: {} ".format(epoch),
                length=15,
                suffix=" Loss: {:.4f}, ".format(running_train_loss / (idx + 1)),
            )

        train_loss = running_train_loss / num_batches
        train_losses.append(train_loss)

        train_dc_loss = running_dice_loss / num_batches
        train_dc_losses.append(train_dc_loss)

        net.eval()
        val_running_loss = 0
        val_running_dc = 0

        # Validation loop
        with torch.no_grad():
            for idx, data in enumerate(val_loader):
                images, labels, img_names = data

                labels = utils.to_var(labels).to(device)
                images = utils.to_var(images).to(device)

                net_predictions = net(images)

                segmentation_classes = utils.getTargetSegmentation(labels)

                loss = CE_loss(net_predictions, segmentation_classes) 
                val_running_loss += loss.item()

                # dice_loss = dice_coefficient(net_predictions, labels)
                dice_loss = utils.compute_dsc(net_predictions, labels)
                val_running_dc += dice_loss

                if idx % 10 == 0:
                    writer.add_scalar(
                        "Loss/val",
                        val_running_loss / (idx + 1),
                        epoch * len(val_loader) + idx,
                    )
                    writer.add_scalar(
                        "Dice/val",
                        val_running_dc / (idx + 1),
                        epoch * len(val_loader) + idx,
                    )

                printProgressBar(
                    idx + 1,
                    len(val_loader),
                    prefix="[Validation] Epoch: {} ".format(epoch),
                    length=15,
                    suffix=" Loss: {:.4f}, ".format(val_running_loss / (idx + 1)),
                )

        val_loss = val_running_loss / len(val_loader)
        val_losses.append(val_loss)
        dc_loss = val_running_dc / len(val_loader)
        val_dc_losses.append(dc_loss)

        # Check if model performed best and save it if true
        if val_loss < best_loss_val:
            best_loss_val = val_loss
            if not os.path.exists("./models/" + modelName):
                os.makedirs("./models/" + modelName)
            torch.save(
                net.state_dict(), "./models/" + modelName + "/" + str(epoch) + "_Epoch"
            )

        printProgressBar(
            num_batches,
            num_batches,
            done="[Epoch: {}, TrainLoss: {:.4f}, TrainDice: {:.4f}, ValLoss: {:.4f}".format(
                epoch, train_loss, train_dc_loss, val_loss
            ),
        )

        np.save(os.path.join(directory, "Losses.npy"), train_losses)
    writer.flush()  # Flush the writer to ensure that all the data is written to disk

In [29]:
# Set up Tensorboard writer
writer = SummaryWriter()
runTraining(writer)
writer.close()

----------------------------------------
~~~~~~~~  Starting the training... ~~~~~~
----------------------------------------
Using device: cpu
~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~
 Model Name: Default_UNet
Total params: 60,664
~~~~~~~~~~~ Starting the training ~~~~~~~~~~
Number of batches:  51
[Training] Epoch: 0 [DONE]                                 
[Validation] Epoch: 0 [DONE]                                 
[Epoch: 0, TrainLoss: 1.0959, TrainDice: 0.0544, ValLoss: 0.9406                                             
Number of batches:  51
[Training] Epoch: 1 [DONE]                                 
[Validation] Epoch: 1 [DONE]                                 
[Epoch: 1, TrainLoss: 0.8382, TrainDice: 0.0544, ValLoss: 0.7349                                             
Number of batches:  51

KeyboardInterrupt: 