# Project MTI865 - Heart segmentation using UNet 

---

# Model evaluation 

## Importing libraries 

In [51]:
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2
from progressBar import printProgressBar

import medicalDataLoader
import argparse
import utils

from UNet_Base import *
import random
import torch
import pdb
import matplotlib.pyplot as plt
import numpy as np
import os

from sklearn import metrics as skmetrics
from scipy import stats 

import warnings
warnings.filterwarnings("ignore") 


## Loading data 

### Batch sizes

In [52]:
# Samples per batch for the labelled, validation and unlabelled data loaders.
batch_size = batch_size_val = batch_size_unlabel = 1

### Mask and image transformation

In [53]:
# Preprocessing pipelines for images and masks.
# Each pipeline consists of a single tensor-conversion step; the two are kept as
# separate objects so they can be changed independently.
transform = v2.Compose([v2.ToTensor()])
mask_transform = v2.Compose([v2.ToTensor()])

### Data loaders 

In [54]:
def collate_fn(batch):
    """Collate dataset samples into batched tensors, tolerating missing masks.

    Args:
        batch: sequence of samples; each sample is indexable and holds the
            image at position 0, the mask (or ``None``) at position 1 and the
            image path at position 2.

    Returns:
        A tuple ``(imgs_tensor, masks_tensor, img_paths)`` where the first two
        entries are the stacked images and stacked masks, and the last one is
        the list of image paths in batch order.
    """
    imgs_tensor = torch.stack([sample[0] for sample in batch])

    # A sample without a mask gets an all-zero placeholder shaped like the
    # first channel of its image (the image is assumed to be C x H x W).
    masks_tensor = torch.stack([
        sample[1] if sample[1] is not None else torch.zeros_like(sample[0][0, :, :])
        for sample in batch
    ])

    img_paths = [sample[2] for sample in batch]

    return imgs_tensor, masks_tensor, img_paths

In [None]:
# Define dataloaders
root_dir = './data/'
print(' Dataset: {} '.format(root_dir))

# --- Datasets (one per split) -------------------------------------------------
supervised_set = medicalDataLoader.MedicalImageDataset('train',
                                                       root_dir,
                                                       transform=transform,
                                                       mask_transform=mask_transform,
                                                       augment=True,
                                                       equalize=False)

val_set = medicalDataLoader.MedicalImageDataset('val',
                                                root_dir,
                                                transform=transform,
                                                mask_transform=mask_transform,
                                                equalize=False)

unsupervised_set = medicalDataLoader.MedicalImageDataset('train-unlabelled',
                                                         root_dir,
                                                         transform=transform,
                                                         mask_transform=mask_transform,
                                                         augment=False,
                                                         equalize=False)

# NOTE(review): augmentation is enabled on the test split; if `augment` applies
# random transforms, test metrics will not be reproducible -- confirm this is
# intended.
test_set = medicalDataLoader.MedicalImageDataset('test',
                                                 root_dir,
                                                 transform=transform,
                                                 mask_transform=mask_transform,
                                                 augment=True,
                                                 equalize=False)

# --- Reproducibility ----------------------------------------------------------
# The previous code passed `worker_init_fn=np.random.seed(0)`, which *calls*
# np.random.seed immediately and hands its return value (None) to the loader, so
# no worker init function was ever registered. Its only real effect was seeding
# NumPy in this process, which is now done explicitly. With num_workers=0 there
# are no worker processes to seed.
np.random.seed(0)

# --- Loaders --------------------------------------------------------------------
supervised_loader = DataLoader(
    supervised_set,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    collate_fn=collate_fn)

# The validation loader uses the same collate function as the other loaders so
# that every loader yields (images, masks, paths) built the same way.
val_loader = DataLoader(
    val_set,
    batch_size=batch_size_val,
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn)

unsupervised_loader = DataLoader(
    unsupervised_set,
    batch_size=batch_size_unlabel,
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn)

test_loader = DataLoader(
    test_set,
    batch_size=batch_size_unlabel,
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn)


# Print the shapes of the first batch of every loader to understand the data
for loader in [supervised_loader, val_loader, unsupervised_loader, test_loader]:
    imgs, masks, img_paths = next(iter(loader))
    print('Images shape: ', imgs.shape)
    print('Masks shape: ', masks.shape)



## Loading a model 

### Loading the parameters

In [None]:
# Select the compute device: CUDA when available, CPU otherwise.
# (On Apple M-series chips, torch.device("mps") could be selected here instead.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epoch_to_load = 96
modelName = 'Test_Model'
checkpoint_path = f"./models/{modelName}/{epoch_to_load}_Epoch"

# NOTE(review): the argument 4 is presumably the number of output classes -- confirm
# against UNet_Base.
model = UNet(4).to(device=device)

# This notebook only runs inference: switch layers such as dropout and batch
# normalisation to evaluation behaviour. The mode flag is independent of the
# weights, so it can be set before they are loaded.
model.eval()

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU machine can still be loaded on a CPU-only one.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

### Visual comparison 

In [None]:
num_batch = len(val_loader)
print('Number of batches: ', num_batch)

# Number of validation batches to display (the loop used to stop once i > 5,
# i.e. after 7 figures).
max_figures = 7

for i, (img, mask, path_img) in enumerate(val_loader):
    print(path_img)
    # NOTE(review): utils.to_var is a project helper; presumably it wraps the
    # tensor and moves it to the GPU when one is available -- confirm.
    img = utils.to_var(img)
    mask = utils.to_var(mask)

    print(f"Image : shape {img.shape}, stats : {stats.describe(img.flatten().cpu().numpy())}")
    img = img.to(device=device)
    mask = mask.to(device=device)

    # Inference only: no_grad avoids building an autograd graph for every
    # forward pass, so no .detach() calls are needed on the results.
    with torch.no_grad():
        pred = model(img)
        probs = torch.softmax(pred, dim=1)
        y_pred = torch.argmax(probs, dim=1)

    print(f"Pred : shape {pred.shape}, stats : {stats.describe(pred.flatten().cpu().numpy())}")
    print(f"Probs : shape {probs.shape}, stats : {stats.describe(probs.flatten().cpu().numpy())}")
    print(f"y_pred : shape {y_pred.shape}, stats : {stats.describe(y_pred.flatten().cpu().numpy())}")

    y_true = utils.getTargetSegmentation(mask)
    print(f"y_true : shape {y_true.shape}, stats : {stats.describe(y_true.flatten().cpu().numpy())}")
    print("="*20)

    # Convert predictions and true values to numpy arrays
    y_pred = y_pred.cpu().numpy()[0]
    y_true = y_true.cpu().numpy()

    # Matplotlib needs host memory: a tensor living on the GPU cannot be
    # converted to a NumPy array, so bring the image back to the CPU first.
    img_to_show = img[0, 0, :, :].cpu().numpy()

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(img_to_show, cmap='gray')
    ax[0].set_title('Image')
    ax[1].imshow(y_true, cmap='gray')
    ax[1].set_title('Ground Truth')
    ax[2].imshow(y_pred, cmap='gray')
    ax[2].set_title('Prediction')
    plt.show()

    if i + 1 >= max_figures:
        break

### Building the confusion matrix over the validation set

In [None]:
num_batch = len(val_loader)
print('Number of batches: ', num_batch)

# Reset the accumulator explicitly so that re-running this cell never adds onto
# the matrix left over from a previous run.
confusionMatrix = None

for img, mask, _ in val_loader:
    print(f"Image : shape {img.shape}, stats : {stats.describe(img.flatten().cpu().numpy())}")
    img = img.to(device=device)
    mask = mask.to(device=device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = model(img)
        probs = torch.softmax(pred, dim=1)
        y_pred = torch.argmax(probs, dim=1)

    print(f"Pred : shape {pred.shape}, stats : {stats.describe(pred.flatten().cpu().numpy())}")
    print(f"y_pred : shape {y_pred.shape}, stats : {stats.describe(y_pred.flatten().cpu().numpy())}")

    y_true = utils.getTargetSegmentation(mask)

    # Convert predictions and true values to flat numpy arrays
    y_pred = y_pred.cpu().numpy().flatten()
    y_true = y_true.cpu().numpy().flatten()

    print(f"y_pred : shape {y_pred.shape}, stats : {stats.describe(y_pred)}")
    print(f"y_true : shape {y_true.shape}, stats : {stats.describe(y_true)}")

    if confusionMatrix is None:
        # The class axis of the model output (dim=1, the axis the softmax is
        # taken over) gives the number of classes.
        num_classes = pred.shape[1]
        class_labels = np.arange(num_classes)
        confusionMatrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    # Passing a fixed label set keeps every per-image matrix at
    # (num_classes, num_classes). Without it, an image that lacks a class
    # produces a smaller matrix, and adding it either fails or silently
    # broadcasts into the wrong cells.
    confusionMatrix += skmetrics.confusion_matrix(y_true, y_pred, labels=class_labels)

if confusionMatrix is None:
    raise RuntimeError("val_loader produced no batches; cannot build a confusion matrix")

print(confusionMatrix)

# Row-normalise (each row = one true class). Rows with no true samples have no
# defined rates and are left as NaN instead of triggering a division by zero.
row_totals = confusionMatrix.sum(axis=1, keepdims=True)
normalizedConfusionMatrix = np.divide(
    confusionMatrix,
    row_totals,
    out=np.full(confusionMatrix.shape, np.nan, dtype=np.float64),
    where=row_totals > 0,
)
print(normalizedConfusionMatrix)