# Project MTI865 - Heart segmentation using UNet 

---

# Model training - CE-DSC with Transformation consistency (MSE) 

$$
\mathcal{L} = \frac{1}{n_{l}} \left(w_{CE} \mathcal{L}_{CE} + w_{DSC} \mathcal{L}_{DSC} \right)  + \frac{\alpha_{TC}}{n_{u}} \mathcal{L}_{TC-MSE}
$$

PAS ENCORE IMPLEMENTÉ !!!! 


## Import libraries

In [1]:
# adding .. to path 
import sys 
sys.path.append('..')

In [2]:
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2
from progressBar import printProgressBar

import medicalDataLoader
import argparse
import utils

from UNet_Base import *
import random
import torch
import pdb
import matplotlib.pyplot as plt
import numpy as np
import os

In [3]:
import warnings
warnings.filterwarnings("ignore") 

## Loading data 

In [4]:
batch_size = 4
batch_size_val = 4
batch_size_unlabel = 8

In [5]:
# Define image and mask transformations
transform = v2.Compose([
    v2.ToTensor(),
    v2.Normalize(mean=[0.137], std=[0.1733]) # Normalisation values for the training set (mean and std) 
])

mask_transform = v2.Compose([
    v2.ToTensor(),
])

In [6]:
def collate_fn(batch):
    """
    Fonction de regroupement pour le DataLoader.

    Args:
    -----
        batch (list): Liste de tuples (image, masque, chemin de l'image).

    Returns:
    -------
        imgs_tensor (torch.Tensor): Batch d'images.
        masks_tensor (torch.Tensor): Batch de masques.
        img_paths (list): Liste des chemins des images.
    """
    imgs = []
    masks = []
    img_paths = []

    for item in batch:
        img, mask, img_path = item[0], item[1], item[2]
        imgs.append(img)
        img_paths.append(img_path)
        
        # Si le masque est None, ajouter un tenseur de zéros correspondant à sa taille
        if mask is not None:
            masks.append(mask)
        else:
            masks.append(torch.zeros_like(img[0, :, :]))  # Même taille que le canal de l'image (assumant CxHxW)

    # Stack les images et les masques
    imgs_tensor = torch.stack(imgs)  # Tensor de forme (B, C, H, W)
    masks_tensor = torch.stack(masks)  # Tensor de forme (B, H, W)

    return imgs_tensor, masks_tensor, img_paths



In [8]:
# Define dataloaders
root_dir = '../data/'
print(' Dataset: {} '.format(root_dir))

supervised_set = medicalDataLoader.MedicalImageDataset(
    'train',
    root_dir,
    transform=transform,
    mask_transform=mask_transform,
    augment=True,
    equalize=False)


supervised_loader = DataLoader(
    supervised_set,
    batch_size=batch_size,
    worker_init_fn=np.random.seed(0),
    num_workers=0,
    shuffle=True,
    collate_fn=collate_fn)


val_set = medicalDataLoader.MedicalImageDataset(
    'val',  
    root_dir,
    transform=transform,
    mask_transform=mask_transform,
    equalize=False)

val_loader = DataLoader(
    val_set,
    batch_size=batch_size_val,
    worker_init_fn=np.random.seed(0),
    num_workers=0,
    shuffle=False)

unsupervised_set = medicalDataLoader.MedicalImageDataset(
    'train-unlabelled',
    root_dir,
    transform=transform,
    mask_transform=mask_transform,
    augment=False,
    equalize=False)

unsupervised_loader = DataLoader(
    unsupervised_set,
    batch_size=batch_size_unlabel,
    worker_init_fn=np.random.seed(0),
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn)



print('Train set: ', len(supervised_set))
print('Validation set: ', len(val_set))

n_train_label = len(supervised_set)
n_train_unlabel = len(unsupervised_set)

# shape of the image a  nd mask
img, mask, _ = supervised_set[0]
print('Image shape: ', img.shape)
print('Mask shape: ', mask.shape)
print('Number of batches: ', len(supervised_loader))

img, mask, _ = val_set[0]
print('Image shape: ', img.shape)
print('Mask shape: ', mask.shape)
print('Number of batches: ', len(val_loader))

img, _, __ = unsupervised_set[0]
print('Image shape: ', img.shape)
print('Mask shape: ', mask.shape)
print('Number of batches: ', len(unsupervised_loader))


# print('First of the supervised set')
# img, mask, path_tuple = supervised_set[0]
# print(img)
# print(mask)
# print(path_tuple)

# print('First of the unsupervised set')
# img, mask, path_tuple = unsupervised_set[0]
# print(img)
# print(mask)
# print(path_tuple)




 Dataset: ../data/ 
Found 204 items in train
First item:  ('../data/train\\Img\\patient006_01_1.png', '../data/train\\GT\\patient006_01_1.png')
Found 74 items in val
First item:  ('../data/val\\Img\\patient001_01_1.png', '../data/val\\GT\\patient001_01_1.png')
Found 1004 items in train-unlabelled
First item:  ('../data/train\\Img-Unlabeled\\patient007_01_1.png', None)
Train set:  204
Validation set:  74
Image shape:  torch.Size([1, 256, 256])
Mask shape:  torch.Size([1, 256, 256])
Number of batches:  51
Image shape:  torch.Size([1, 256, 256])
Mask shape:  torch.Size([1, 256, 256])
Number of batches:  19
Image shape:  torch.Size([1, 256, 256])
Mask shape:  torch.Size([1, 256, 256])
Number of batches:  126


## Model using both labeled and unlabeled data 

### Hyperparameters of the model

In [9]:
# Parameters 
lr =  0.0005    # Learning Rate
total_epochs = 150  # Number of epochs
weight_TC = 0.1 # Alpha parameter for the consistency loss term 
weight_decay = 1e-5  # Weight decay
ce_loss_weight = 0.7 # Cross Entropy Loss Weight proportion
dice_loss_weight = 0.3 # Dice Loss Weight proportion 

modelName = f"TransformationConsistencyL2Model-{total_epochs}epochs{lr}lr{weight_TC}alphaTC{weight_decay}wd"
model_dir = f"models/{modelName}"
# write params in a file 
param_dict = {
    "lr": lr,
    "total_epochs": total_epochs,
    "weight_TC": weight_TC,
    "weight_decay": weight_decay,
    "ce_loss_weight": ce_loss_weight,
    "dice_loss_weight": dice_loss_weight,
    "modelName": modelName,
    "model": "ALL"
}


os.makedirs(model_dir, exist_ok=True)
with open(f"{model_dir}/params.txt", 'w') as f:
    print(param_dict, file=f)

print(f"Parameters saved to {model_dir}/params.txt")

print(f"Parameters saved to {model_dir}/params.txt")

Parameters saved to models/TransformationConsistencyL2Model-150epochs0.0005lr0.1alphaTC1e-05wd/params.txt
Parameters saved to models/TransformationConsistencyL2Model-150epochs0.0005lr0.1alphaTC1e-05wd/params.txt


### Transformation consistency regularisation

The transformation consistency consists in the principle that transformation T suchs as rotation and flipping should affect the mask f(y) only by the same rotation, which means that f and T should be symetrical. In this implementation, we used the 2-norm to measure the difference, and we included it in the optimisation problem.  $\mathcal{L}_{TC}(y_u) = \|f(T(y_u))-T(F(y))\|_2$. 
Il est aussi possible de faire une régularisation avec la CE : 

In [10]:
from torchvision import transforms

class ConsistencyRegularization(nn.Module):
    def __init__(self, transformation_fn, loss_fn=nn.MSELoss()):
        """
        Régularisation basée sur la consistance à la transformation.

        Args:
        -----
            transformation_fn (callable): Fonction d'augmentation/transformation appliquée aux images.
            loss_fn (callable): Fonction de perte utilisée pour comparer les prédictions (par défaut MSELoss). Aussi possible d'utiliser 
                                nn.KLDivLoss ou nn.BCELoss.
        """
        super(ConsistencyRegularization, self).__init__()
        self.transformation_fn = transformation_fn
        self.loss_fn = loss_fn

    def forward(self, model, images):
        """
        Calcule la perte de consistance.

        Args:
            model (torch.nn.Module): Le modèle de segmentation.
            images (torch.Tensor): Batch d'images d'entrée.

        Returns:
            torch.Tensor: La perte de consistance.
        """
        with torch.no_grad():
            # Prédictions de base
            original_predictions = F.softmax(model(images), dim=1)

        # Augmenter les images
        augmented_images = self.transformation_fn(images)

        # Prédictions pour les images augmentées
        augmented_predictions = F.softmax(model(augmented_images), dim=1)

        # Calcul de la perte de consistance
        consistency_loss = self.loss_fn(original_predictions, augmented_predictions)

        return consistency_loss
    



In [11]:
def random_transformation_fn(images):
  
    # Random horizontal flip
    if np.random.random() > 0.5:
        images = torch.flip(images, dims=[2])
    # Random vertical flip
    if np.random.random() > 0.5:
        images = torch.flip(images, dims=[3])
    # Random rotation of random angle
    if np.random.random() > 0.5:
        angle = np.random.randint(0, 360)
        images = torch.rot90(images, k=angle//90, dims=[2, 3])
    return images


### Training of the model 

At each epoch, the model sees once every exemple of unlabeled data, and sees several time the labeled data. We first train it with the labeled data, and then we train it on the unsupervised data. 

In [12]:
writer = SummaryWriter()

print("-" * 40)
print("~~~~~~~~  Starting the training... ~~~~~~")
print("-" * 40)

num_classes = 4

# Set device depending on the availability of GPU
if torch.cuda.is_available():
    device = torch.device("cuda")
# elif torch.mps.is_available():  # Apple M-series of chips
#     device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
print(" Model Name: {}".format(modelName))

## CREATION OF YOUR MODEL
net = UNet(num_classes).to(device)

print(
    "Total params: {0:,}".format(
        sum(p.numel() for p in net.parameters() if p.requires_grad)
    )
)

# DEFINE YOUR OUTPUT COMPONENTS (e.g., SOFTMAX, LOSS FUNCTION, ETC)
softMax = torch.nn.Softmax(dim=1)
CE_loss = torch.nn.CrossEntropyLoss()
consistency_regularizer = ConsistencyRegularization(transformation_fn=random_transformation_fn)


## PUT EVERYTHING IN GPU RESOURCES
if torch.cuda.is_available():
    net.cuda()
    softMax.cuda()
    CE_loss.cuda()

## DEFINE YOUR OPTIMIZER
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

### To save statistics ####
train_losses = []
train_dc_losses = []
val_losses = []
val_dc_losses = []

best_loss_val = 1000

directory = "Results/Statistics/" + modelName

print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
if os.path.exists(directory) == False:
    os.makedirs(directory)

## START THE TRAINING

## FOR EACH EPOCH
for epoch in range(total_epochs):
    net.train()
    supervised_iter = iter(supervised_loader)
    unsupervised_iter = iter(unsupervised_loader)
    
    num_batches = max(len(supervised_loader), len(unsupervised_loader))
    print("Number of batches: ", num_batches)

    running_train_loss = 0
    running_dice_loss = 0

    # Training loop
    for idx in range(num_batches):
        ### SUPERVISED BATCH
        try :
            supervised_data = next(supervised_iter)
        except StopIteration:
            supervised_iter = iter(supervised_loader)
            supervised_data = next(supervised_iter)

        ### Set to zero all the gradients
        net.zero_grad()
        optimizer.zero_grad()

        ## GET IMAGES, LABELS and IMG NAMES
        images, labels, img_names = supervised_data

        ### From numpy to torch variables
        labels = utils.to_var(labels).to(device)
        images = utils.to_var(images).to(device)

        # Forward pass
        net_predictions = net(images)

        # Get the segmentation classes
        segmentation_classes = utils.getTargetSegmentation(labels)
        # Modify segmentation classes to be one-hot encoded (shape [batch_size, num_classes, height, width])
        dice_target = F.one_hot(segmentation_classes, num_classes = num_classes).permute(0,3,1,2).contiguous()

        # Compute the loss
        ce_loss = ce_loss_weight * CE_loss(net_predictions, segmentation_classes) 
        dice_loss = dice_loss_weight * DiceLoss()(net_predictions, dice_target)  
        loss = ce_loss + dice_loss 
        running_train_loss += ce_loss.item() + dice_loss.item() 
        # dice_loss = dice_coefficient(net_predictions, labels)
        # dice_loss = utils.compute_dsc(net_predictions, labels)
        running_dice_loss += dice_loss

        # Backprop
        # loss.backward()
        # optimizer.step()

        ### UNSUPERVISED BATCH
        try :
            unsupervised_data = next(unsupervised_iter)
        except StopIteration:
            unsupervised_iter = iter(unsupervised_loader)
            unsupervised_data = next(unsupervised_iter)
        
        unsupervised_images, _, __ = unsupervised_data
        unsupervised_images = utils.to_var(unsupervised_images).to(device)

        # net.zero_grad()
        # optimizer.zero_grad()

        consistency_loss = weight_TC * consistency_regularizer(net, unsupervised_images) 
        loss += consistency_loss
        
        loss.backward()
        optimizer.step()

        running_train_loss += consistency_loss.item()
        running_dice_loss += 0

        # Add the loss to the tensorboard every 5 batches
        if idx % 10 == 0:
            writer.add_scalar(
                "Loss/train", running_train_loss / (idx + 1), epoch * len(supervised_loader) + idx
            )
            writer.add_scalar(
                "Dice/train", running_dice_loss / (idx + 1), epoch * len(supervised_loader) + idx
            )

        if idx % 100 == 0:
            # Also add visualizations of the images
            probs = torch.softmax(net_predictions, dim=1)
            y_pred = torch.argmax(probs, dim=1)
            writer.add_figure('predictions vs. actuals',
                        utils.plot_net_predictions(images, labels, y_pred, batch_size),
                        global_step=epoch * len(supervised_loader) + idx)

        # THIS IS JUST TO VISUALIZE THE TRAINING
        printProgressBar(
            idx + 1,
            num_batches,
            prefix="[Training] Epoch: {} ".format(epoch),
            length=15,
            suffix=" Loss: {:.4f}, ".format(running_train_loss / (idx + 1)),
        )
        print(f"Epoch {epoch}, Batch {idx}, CE_loss: {ce_loss.item()}, Dice_loss: {dice_loss.item()}, Consistency_loss: {consistency_loss.item()}")

    train_loss = running_train_loss / num_batches
    train_losses.append(train_loss)

    train_dc_loss = running_dice_loss / num_batches
    train_dc_losses.append(train_dc_loss)
    # print(f"Epoch {epoch}, Batch {idx}, CE_loss: {ce_loss.item()}, Dice_loss: {dice_loss.item()}, Consistency_loss: {consistency_loss.item()}")
    net.eval()
    val_running_loss = 0
    val_running_dc = 0

    # Validation loop
    with torch.no_grad():
        for idx, data in enumerate(val_loader):
            images, labels, img_names = data

            labels = utils.to_var(labels).to(device)
            images = utils.to_var(images).to(device)

            net_predictions = net(images)

            segmentation_classes = utils.getTargetSegmentation(labels)

            loss = CE_loss(net_predictions, segmentation_classes) 
            val_running_loss += loss.item()

            # dice_loss = dice_coefficient(net_predictions, labels)
            dice_loss = utils.compute_dsc(net_predictions, labels)
            val_running_dc += dice_loss

            if idx % 10 == 0:
                writer.add_scalar(
                    "Loss/val",
                    val_running_loss / (idx + 1),
                    epoch * len(val_loader) + idx,
                )
                writer.add_scalar(
                    "Dice/val",
                    val_running_dc / (idx + 1),
                    epoch * len(val_loader) + idx,
                )

            printProgressBar(
                idx + 1,
                len(val_loader),
                prefix="[Validation] Epoch: {} ".format(epoch),
                length=15,
                suffix=" Loss: {:.4f}, ".format(val_running_loss / (idx + 1)),
            )

    val_loss = val_running_loss / len(val_loader)
    val_losses.append(val_loss)
    dc_loss = val_running_dc / len(val_loader)
    val_dc_losses.append(dc_loss)

    # Check if model performed best and save it if true
    if val_loss < best_loss_val:
        best_loss_val = val_loss
        if not os.path.exists("./models/" + modelName):
            os.makedirs("./models/" + modelName)
        torch.save(
            net.state_dict(), "./models/" + modelName + "/" + str(epoch) + "_Epoch"
        )

    printProgressBar(
        num_batches,
        num_batches,
        done="[Epoch: {}, TrainLoss: {:.4f}, TrainDice: {:.4f}, ValLoss: {:.4f}".format(
            epoch, train_loss, train_dc_loss, val_loss
        ),
    )

    np.save(os.path.join(directory, "Losses.npy"), train_losses)
writer.flush()  # Flush the writer to ensure that all the data is written to disk
writer.close()

----------------------------------------
~~~~~~~~  Starting the training... ~~~~~~
----------------------------------------
Using device: cpu
~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~
 Model Name: TransformationConsistencyL2Model-150epochs0.0005lr0.1alphaTC1e-05wd
Total params: 60,664
~~~~~~~~~~~ Starting the training ~~~~~~~~~~
Number of batches:  126
[Training] Epoch: 0 [>              ] 0.8% Loss: 1.4232, Epoch 0, Batch 0, CE_loss: 1.2124372720718384, Dice_loss: 0.20997017621994019, Consistency_loss: 0.0008164049359038472
[Training] Epoch: 0 [>              ] 1.6% Loss: 1.4025, Epoch 0, Batch 1, CE_loss: 1.1755757331848145, Dice_loss: 0.20536848902702332, Consistency_loss: 0.0008156333933584392
[Training] Epoch: 0 [>              ] 2.4% Loss: 1.3952, Epoch 0, Batch 2, CE_loss: 1.1751430034637451, Dice_loss: 0.20456987619400024, Consistency_loss: 0.0009126115473918617
[Training] Epoch: 0 [>              ] 3.2% Loss: 1.3929, Epoch 0, Batch 3, CE_loss: 1.1803746223449707, Dice_los

Exception in thread Thread-6:
Traceback (most recent call last):
  File "c:\Users\nicos\anaconda3\lib\threading.py", line 980, in _bootstrap_inner
    self.run()
  File "c:\Users\nicos\anaconda3\lib\site-packages\tensorboard\summary\writer\event_file_writer.py", line 244, in run
    self._run()
  File "c:\Users\nicos\anaconda3\lib\site-packages\tensorboard\summary\writer\event_file_writer.py", line 275, in _run
    self._record_writer.write(data)
  File "c:\Users\nicos\anaconda3\lib\site-packages\tensorboard\summary\writer\record_writer.py", line 40, in write
    self._writer.write(header + header_crc + data + footer_crc)
  File "c:\Users\nicos\anaconda3\lib\site-packages\tensorboard\compat\tensorflow_stub\io\gfile.py", line 775, in write
    self.fs.append(self.filename, file_content, self.binary_mode)
  File "c:\Users\nicos\anaconda3\lib\site-packages\tensorboard\compat\tensorflow_stub\io\gfile.py", line 167, in append
    self._write(filename, file_content, "ab" if binary_mode else 

FileNotFoundError: [Errno 2] No such file or directory: b'runs\\Dec16_09-03-18_PCDuDieuDesMathsNico\\events.out.tfevents.1734357798.PCDuDieuDesMathsNico.33924.0'