In [None]:
!pip install nibabel monai matplotlib torch wandb scipy pandas numpy medpy
import wandb
wandb.login()

Defaulting to user installation because normal site-packages is not writeable


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtijmenl[0m ([33mtijmenl-universiteit-twente[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
#%% Adapted Script for Heart MRI Segmentation using a 2D UNet with Cross-Validation and Learning Rate Scheduling
import os
import glob
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import monai
from monai.transforms import (
    Compose, EnsureChannelFirstd, ScaleIntensityd, Resized, RandZoomd, RandFlipd, RandRotated,
    Rand2DElasticd, RandAdjustContrastd, RandGaussianSmoothd, RandGaussianNoised, RandShiftIntensityd
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.networks.utils import one_hot
from monai.losses import DiceCELoss
from sklearn.model_selection import KFold
from tqdm import tqdm
import wandb

#%% Utility function for loading NIfTI images
def load_nii(img_path):
    """Read a NIfTI file and return its voxel data, affine and header.

    Args:
        img_path: Path passed to ``nib.load``.

    Returns:
        tuple: ``(data, affine, header)`` where ``data`` is the result of
        calling ``get_fdata()`` on the loaded image.
    """
    nifti_image = nib.load(img_path)
    voxel_data = nifti_image.get_fdata()
    return voxel_data, nifti_image.affine, nifti_image.header

#%% Build dataset dictionary
def build_dict(data_path):
    """Pair every training image with its ground-truth segmentation file.

    Images are read from ``<data_path>/training/image/*.nii.gz`` and masks from
    ``<data_path>/training/segmentation/*_gt.nii.gz``. A mask belongs to an
    image when its file name, with every ``"_gt"`` removed, equals the image
    file name. Images without a matching mask are skipped.

    Args:
        data_path: Root folder of the dataset.

    Returns:
        list[dict]: One dict per matched pair, in sorted image order, with keys
        ``"patient"`` (file-name text before the first underscore), ``"img"``
        and ``"mask"`` (file paths).
    """
    training_root = os.path.join(data_path, "training")
    image_files = sorted(glob.glob(os.path.join(training_root, "image", "*.nii.gz")))
    mask_files = sorted(glob.glob(os.path.join(training_root, "segmentation", "*_gt.nii.gz")))

    # Index the masks by the image file name they correspond to.
    masks_by_image_name = {}
    for mask_file in mask_files:
        masks_by_image_name[os.path.basename(mask_file).replace("_gt", "")] = mask_file

    records = []
    for image_file in image_files:
        image_name = os.path.basename(image_file)
        matched_mask = masks_by_image_name.get(image_name)
        if not matched_mask or not os.path.exists(matched_mask):
            continue
        records.append({
            "patient": image_name.split('_')[0],
            "img": image_file,
            "mask": matched_mask,
        })
    return records

#%% Custom Transform to Load All Slices of Data
class LoadHeartData(monai.transforms.Transform):
    """Expand one ``{'img': path, 'mask': path}`` sample into per-slice dicts.

    Both volumes are loaded with ``load_nii`` and split along their last axis.
    Each returned dict holds one image slice (float32), the matching mask slice
    (uint8) and placeholder ``*_meta_dict`` entries carrying a 2x2 identity
    affine. Any other key of the input sample (e.g. ``'patient'``) is not
    carried over.
    """

    def __call__(self, sample):
        image_volume = load_nii(sample['img'])[0]
        label_volume = load_nii(sample['mask'])[0]

        # Move the last axis to the front so indexing axis 0 walks through the slices.
        image_slices = np.moveaxis(image_volume, -1, 0)
        label_slices = np.moveaxis(label_volume, -1, 0)

        return [
            {
                'img': image_slices[slice_idx].astype(np.float32),
                'mask': label_slices[slice_idx].astype(np.uint8),
                'img_meta_dict': {'affine': np.eye(2)},
                'mask_meta_dict': {'affine': np.eye(2)},
            }
            for slice_idx in range(image_slices.shape[0])
        ]

#%% Set data path
# Dataset root; build_dict looks for "training/image" and "training/segmentation" below it.
main_path = './database'

#%% Build dataset dictionary
dataset_dicts = build_dict(main_path)

#%% Define transforms
# Deterministic preprocessing: load slices, add a channel axis, rescale image
# intensities and resample image (bilinear) and mask (nearest) to 256x256.
preprocessing_steps = [
    LoadHeartData(),
    EnsureChannelFirstd(keys=['img', 'mask'], channel_dim="no_channel"),
    ScaleIntensityd(keys=['img']),
    Resized(keys=['img', 'mask'], spatial_size=(256, 256), mode=['bilinear', 'nearest']),
]

# Random augmentations. Spatial ones act on image and mask together (mask with
# nearest-neighbour interpolation); intensity ones act on the image only.
augmentation_steps = [
    RandZoomd(keys=['img', 'mask'], min_zoom=0.90, max_zoom=1.10, mode=['bilinear', 'nearest'], prob=0.5),
    RandFlipd(keys=['img', 'mask'], prob=0.5, spatial_axis=1),
    RandRotated(keys=['img', 'mask'], range_x=0.1, range_y=0.1, mode=['bilinear', 'nearest'], prob=0.5),
    Rand2DElasticd(keys=['img', 'mask'], spacing=(5, 5), magnitude_range=(0, 0.1), prob=0.5, mode=['bilinear', 'nearest']),
    RandAdjustContrastd(keys=['img'], gamma=(0.7, 1.3), prob=0.3),
    RandGaussianSmoothd(keys=['img'], sigma_x=(0.5, 1.5), sigma_y=(0.5, 1.5), prob=0.3),
    RandGaussianNoised(keys=['img'], prob=0.3, mean=0.0, std=0.05),
    # NOTE(review): offsets=(10, 20) looks very large if ScaleIntensityd rescales
    # images to [0, 1] (its documented default) — TODO confirm this is intended.
    RandShiftIntensityd(keys=['img'], prob=0.5, offsets=(10,20)),
]

transforms = Compose(preprocessing_steps + augmentation_steps)

#%% Flatten dataset to handle all slices
def flatten_dataset(dataset_list, transform):
    """Apply ``transform`` to every entry and concatenate the results.

    Args:
        dataset_list: Iterable of samples.
        transform: Callable returning an iterable of items for one sample.

    Returns:
        list: All items produced by ``transform``, in input order.
    """
    return [item for sample in dataset_list for item in transform(sample)]

# Every volume goes through the (random) pipeline three times, so each slice is
# stored with three independent augmentation draws. The result is wrapped in a
# NumPy object array so the integer index arrays from KFold can select subsets.
# NOTE(review): augmentation is applied once, here, before the split; validation
# slices are therefore augmented too, and the slice-level split below can place
# slices (and augmented copies) of the same volume in both train and validation.
full_dataset = np.array(flatten_dataset(dataset_dicts * 3, transforms))

#%% Cross-Validation Setup
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
k_folds = 5
num_epochs = 200
kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

# Dice over all classes (background included), averaged; shared across folds and
# reset at the start of every validation pass.
dice_metric = DiceMetric(include_background=True, reduction="mean")
# One summary dict per fold is appended here by the cross-validation loop.
fold_results = list()

#%% Cross-validation loop
class UNetWithDropout(monai.networks.nets.UNet):
    """MONAI UNet whose output is passed through a dropout layer.

    Defined once, ahead of the fold loop, instead of being re-declared per fold.

    NOTE(review): the dropout acts on the network's final output (the logits),
    not after each block; it is only active in ``train()`` mode. Confirm this
    is the intended regularisation.

    Args:
        spatial_dims, in_channels, out_channels, channels, strides,
        num_res_units: Forwarded unchanged to ``monai.networks.nets.UNet``.
        dropout_prob: Drop probability of the output dropout layer.
    """

    def __init__(self, spatial_dims, in_channels, out_channels, channels, strides, num_res_units, dropout_prob=0.5):
        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            channels=channels,
            strides=strides,
            num_res_units=num_res_units
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        return self.dropout(super().forward(x))


for fold, (train_idx, val_idx) in enumerate(kf.split(full_dataset)):
    print(f"Training Fold {fold + 1}/{k_folds}...")

    # Split dataset for this fold.
    # NOTE(review): the split is over individual slices, so slices of the same
    # volume can end up in both subsets; a patient-level split would avoid this.
    train_subset = full_dataset[train_idx].tolist()
    val_subset = full_dataset[val_idx].tolist()

    train_dataset = monai.data.Dataset(data=train_subset)
    val_dataset = monai.data.Dataset(data=val_subset)

    # Shuffle the training set every epoch; otherwise batches always contain the
    # same neighbouring slices in the same order.
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # A fresh model per fold.
    model = UNetWithDropout(
        spatial_dims=2,
        in_channels=1,
        out_channels=4,
        channels=(32, 64, 128, 256, 512),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        dropout_prob=0.5
    ).to(device)

    # Loss function & optimizer
    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, include_background=True)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Halve the learning rate when the validation loss has not improved for 10 epochs.
    # (`verbose=True` is deprecated in PyTorch; the current LR is logged to WandB instead.)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    # Initialize WandB for each fold
    run = wandb.init(
        project='ACDC 2D Unet Cross-Validation',
        name=f'Fold_{fold + 1}',
        config={'fold': fold + 1, 'batch_size': train_dataloader.batch_size}
    )

    # Early stopping and best-checkpoint tracking
    patience = 20  # Epochs to wait for a validation-loss improvement before stopping
    epochs_without_improvement = 0
    best_val_loss = float("inf")
    best_dice_score = -float("inf")  # Dice score at the epoch with the best validation loss
    best_epoch = 0
    best_model_wts = None

    # Training loop for this fold
    for epoch in tqdm(range(num_epochs)):
        model.train()
        train_loss_epoch = 0.0
        for batch_data in train_dataloader:
            images = batch_data["img"].to(device)
            masks = batch_data["mask"].to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_function(outputs, masks)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            train_loss_epoch += loss.item() * images.size(0)
        train_loss = train_loss_epoch / len(train_dataloader.dataset)

        # Validation phase
        model.eval()
        val_loss_epoch = 0.0
        dice_metric.reset()
        with torch.no_grad():
            for batch_data in val_dataloader:
                images = batch_data["img"].to(device)
                masks = batch_data["mask"].to(device)
                outputs = model(images)
                loss = loss_function(outputs, masks)
                val_loss_epoch += loss.item() * images.size(0)

                # Hard label map [B, 1, H, W], then one-hot for the Dice metric.
                pred_labels = torch.argmax(outputs, dim=1, keepdim=True)
                outputs_onehot = one_hot(pred_labels, num_classes=4)
                # Add a channel axis only when the masks lack one. Unsqueezing a
                # mask that already has it would give [B, 1, 1, H, W], which no
                # longer matches the prediction shape.
                if masks.ndim == pred_labels.ndim - 1:
                    masks = masks.unsqueeze(1)
                masks_onehot = one_hot(masks, num_classes=4)
                dice_metric(y_pred=outputs_onehot, y=masks_onehot)

        val_loss = val_loss_epoch / len(val_dataloader.dataset)
        dice_score = dice_metric.aggregate().item()

        wandb.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'dice_score': dice_score,
            'lr': optimizer.param_groups[0]['lr'],
        })

        # Step the scheduler with validation loss
        scheduler.step(val_loss)

        # Early stopping: check whether the validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_dice_score = dice_score
            best_epoch = epoch
            epochs_without_improvement = 0
            # state_dict() returns references to the live parameters, which keep
            # changing as training continues. Snapshot a detached copy so the
            # checkpoint really is the best epoch rather than the last one.
            best_model_wts = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1} due to no improvement in validation loss.")
                break  # No improvement in validation loss for `patience` epochs

    # Save the best model for this fold
    if best_model_wts is None:
        # No epoch ever improved on +inf (e.g. the loss was NaN throughout);
        # fall back to the final weights rather than saving None.
        best_model_wts = model.state_dict()
    torch.save(best_model_wts, f'bestHeartUNet_Fold{fold + 1}.pt')
    fold_results.append({"val_loss": best_val_loss, "dice_score": best_dice_score, "best_epoch": best_epoch})
    run.finish()

#%% Function to load ensemble models
def load_ensemble_models(model_paths, device):
    """Load one trained UNet per checkpoint path.

    Each checkpoint is expected to be a state dict saved by the cross-validation
    loop. A plain ``UNet`` with the same architecture is used for loading: the
    training-time ``UNetWithDropout`` only adds a parameter-free ``nn.Dropout``,
    so it contributes no state-dict entries.

    Args:
        model_paths: Iterable of checkpoint file paths.
        device: Device the models are created on and the weights are mapped to.

    Returns:
        list: One model per path, in input order, each switched to ``eval()``
        mode. Empty when ``model_paths`` is empty.
    """
    models = []
    for path in model_paths:
        model = UNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=4,
            channels=(32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        model.load_state_dict(torch.load(path, map_location=device))
        model.eval()
        models.append(model)
    # Return only after every checkpoint has been loaded. The original returned
    # from inside the loop, so the "ensemble" contained just the first model.
    return models

#%% Define the EnsembleModel globally to avoid pickle issues
class EnsembleModel(torch.nn.Module):
    """Soft-voting ensemble: averages the softmax outputs of its member models.

    Args:
        models: Iterable of ``torch.nn.Module`` members; stored as a
            ``ModuleList`` under ``self.models``.
    """

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)

    def forward(self, x):
        """Return the mean over members of ``softmax(member(x), dim=1)``."""
        member_probs = []
        for member in self.models:
            member_probs.append(member(x).softmax(dim=1))
        # Stack along a new leading "member" axis and average it away.
        return torch.stack(member_probs).mean(dim=0)

#%% Save the ensemble model
def save_ensemble_model(model_paths, device, save_path):
    """Build an ``EnsembleModel`` from checkpoints and write it to disk.

    Args:
        model_paths: Checkpoint paths handed to ``load_ensemble_models``.
        device: Device the member models are loaded onto.
        save_path: Destination file for the saved ensemble.

    Returns:
        EnsembleModel: The ensemble that was saved.
    """
    ensemble_model = EnsembleModel(load_ensemble_models(model_paths, device))

    # The whole module object is saved, not just its state_dict.
    torch.save(ensemble_model, save_path)
    print(f"Averaged Ensemble Model saved at: {save_path}")
    return ensemble_model

#%% Main Execution
# One checkpoint per cross-validation fold, named exactly as the training loop
# writes them. Uses k_folds rather than a second hard-coded 5 so the two cannot
# drift apart.
model_paths = [f"bestHeartUNet_Fold{fold_number}.pt" for fold_number in range(1, k_folds + 1)]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Save the entire ensemble model
ensemble_model_path = "ensembleHeartUNet.pt"
# Bind the result so the cell does not echo the full model repr as its output.
ensemble_model = save_ensemble_model(model_paths, device, ensemble_model_path)

2025-03-31 20:26:00.305408: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-31 20:26:00.342816: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-31 20:26:00.342834: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-31 20:26:00.342855: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-31 20:26:00.350422: I tensorflow/core/platform/cpu_feature_g

Training Fold 1/5...


The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.


 39%|███▉      | 78/200 [1:05:49<1:42:57, 50.63s/it]

Early stopping at epoch 79 due to no improvement in validation loss.





0,1
dice_score,▁▇▇▆▇██▇████████████████████████████████
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇██
train_loss,█▆▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▄▃▃▃▂▁▂▁▁▃▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
dice_score,0.90356
epoch,78.0
train_loss,1.09942
val_loss,0.18812


Training Fold 2/5...


 30%|██▉       | 59/200 [45:50<1:49:44, 46.70s/it]