Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.


# Cross validation and models ensemble to achieve better test metrics

Cross validation and model ensembling are popular strategies in machine learning and deep learning for achieving more accurate and more stable outputs.  
A typical practice is:
* Split all the training dataset into K folds.
* Train K models, each on a different set of K-1 folds, using the held-out fold for validation (cross validation).
* Execute inference on the test data with all the K models.
* Compute the average values with weights or vote the most common value as the final result (ensemble).
<p>
<img src="../figures/models_ensemble.png" width="80%" alt='models_ensemble'>
</p>

MONAI provides the `CrossValidation` dataset to automatically split the data into K folds for cross validation.

It also provides the `EnsembleEvaluator` engine and the `MeanEnsemble` and `VoteEnsemble` post transforms for model ensembling.

This tutorial shows how to leverage cross validation and ensemble modules in MONAI to set up a program.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/cross_validation_models_ensemble.ipynb)

## Setup environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[ignite, nibabel, tqdm]"

## Setup imports

In [None]:
from abc import ABC, abstractmethod
import logging
import os
import tempfile
import shutil
import sys

import nibabel as nib
import numpy as np
import torch

from monai.apps import CrossValidation
from monai.config import print_config
from monai.data import CacheDataset, DataLoader, create_test_image_3d
from monai.engines import EnsembleEvaluator, SupervisedEvaluator, SupervisedTrainer
from monai.handlers import MeanDice, StatsHandler, ValidationHandler, from_engine
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
    Activationsd,
    EnsureChannelFirstd,
    AsDiscreted,
    Compose,
    LoadImaged,
    MeanEnsembled,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
    EnsureTyped,
    VoteEnsembled,
)
from monai.utils import set_determinism

print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [2]:
# Resolve the working directory for this tutorial.
# An unset *or empty* MONAI_DATA_DIRECTORY is treated as "not specified":
# `directory` becomes None and a fresh temporary directory is used instead.
# (An empty string would otherwise make `root_dir` "" and write data relative
# to the current working directory.)
directory = os.environ.get("MONAI_DATA_DIRECTORY") or None
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/workspace/data/medical


## Set determinism, logging, device

In [3]:
# Fix the random seed so repeated runs of the notebook are comparable.
set_determinism(seed=0)
# Send INFO-level log records (including the engine logs) to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the tutorial still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Generate random (image, label) pairs

Generate 60 pairs for the task, 50 for training and 10 for test.  
And then split the 50 pairs into 5 folds to train 5 separate models.

In [18]:
# All synthetic volumes are written below <root_dir>/runs.
data_dir = os.path.join(root_dir, "runs")
# One {"image": path, "label": path} dict per generated pair.
datalist = []

# exist_ok=True replaces the separate existence check, so re-running the cell
# (or a concurrent creation of the directory) does not raise.
os.makedirs(data_dir, exist_ok=True)
for i in range(60):
    # channel_dim=-1 matches the `channel_dim=-1` used by EnsureChannelFirstd
    # in the transforms that load these files.
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

    # Save the image with an identity affine.
    n = nib.Nifti1Image(im, np.eye(4))
    img_path = os.path.join(data_dir, f"img{i}.nii.gz")
    nib.save(n, img_path)

    # Save the matching segmentation with the same affine.
    n = nib.Nifti1Image(seg, np.eye(4))
    seg_path = os.path.join(data_dir, f"seg{i}.nii.gz")
    nib.save(n, seg_path)

    datalist.append({"image": img_path, "label": seg_path})

## Create the base class for cross validation datasets
Here we leverage the `monai.apps.CrossValidation` to automatically split the datasets for cross validation.

In [7]:
class CVDataset(ABC, CacheDataset):
    """
    Abstract cached dataset used as the base for cross validation splits.

    The incoming data list is first passed through `_split_datalist`, which
    subclasses must provide, and only the returned items are handed to
    `CacheDataset`.

    """

    def __init__(self, data, transform, cache_num=sys.maxsize, cache_rate=1.0, num_workers=4) -> None:
        # Let the subclass pick the items that belong to this split.
        selected_items = self._split_datalist(datalist=data)
        # Initialise the CacheDataset part explicitly with the selected items.
        CacheDataset.__init__(
            self,
            selected_items,
            transform,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=num_workers,
        )

    @abstractmethod
    def _split_datalist(self, datalist):
        """Return the subset of `datalist` that this dataset should contain."""
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

## Setup transforms for training and validation

In [8]:
# Training pipeline: the same loading / channel handling / intensity scaling
# steps as the validation pipeline below, followed by random augmentation.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        # the channel axis of the loaded data is declared as the last one (channel_dim=-1)
        EnsureChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        # draw num_samples=4 crops of 96x96x96 per volume, with "label" as the
        # reference for the positive/negative sampling (pos=1, neg=1)
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=[96, 96, 96],
            pos=1,
            neg=1,
            num_samples=4,
        ),
        # random 90-degree rotation over spatial axes 0 and 2, applied with prob=0.5
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["image", "label"]),
    ]
)
# Validation / test pipeline: no random steps, so whole volumes are evaluated.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
    ]
)

## Get Datasets and DataLoaders for train, validation and test

In [None]:
# Number of folds, and therefore the number of models that will be trained.
num = 5
folds = list(range(num))

# The first 50 pairs are split into `num` folds; the fold count is taken from
# `num` so the split always agrees with the fold indices used below.
cvdataset = CrossValidation(
    dataset_cls=CVDataset,
    data=datalist[0:50],
    nfolds=num,
    seed=12345,
    transform=train_transforms,
)

# For model i: train on every fold except fold i, validate on fold i.
train_dss = [cvdataset.get_dataset(folds=folds[0:i] + folds[(i + 1) :]) for i in folds]
val_dss = [cvdataset.get_dataset(folds=i, transform=val_transforms) for i in folds]

train_loaders = [DataLoader(train_dss[i], batch_size=2, shuffle=True, num_workers=4) for i in folds]
val_loaders = [DataLoader(val_dss[i], batch_size=1, num_workers=4) for i in folds]

# The remaining 10 pairs are never seen during training and form the test set.
test_ds = CacheDataset(data=datalist[50:], transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4)

## Define a training process based on workflows

More usage examples of MONAI workflows are available at: [workflow examples](https://github.com/Project-MONAI/tutorials/tree/main/modules/engines).

In [22]:
def train(index):
    """
    Train one UNet on `train_loaders[index]` and return the trained network.

    A `SupervisedEvaluator` running on `val_loaders[index]` is attached to the
    trainer through a `ValidationHandler` (interval=4, epoch level); training
    runs for 4 epochs.

    Args:
        index: position of the train/validation loader pair to use.

    Returns:
        The network after `trainer.run()` has finished.
    """
    # Build the network first, exactly as before, so seeded initialisation is unchanged.
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    criterion = DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # Post-processing of validation predictions: sigmoid, then threshold at 0.5.
    post_pred = Compose(
        [
            EnsureTyped(keys="pred"),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
        ]
    )

    sliding_window = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5)
    dice_metric = MeanDice(
        include_background=True,
        output_transform=from_engine(["pred", "label"]),
    )
    fold_evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loaders[index],
        network=model,
        inferer=sliding_window,
        postprocessing=post_pred,
        key_val_metric={"val_mean_dice": dice_metric},
    )

    # Handlers attached to the trainer: periodic validation and loss logging.
    run_validation = ValidationHandler(validator=fold_evaluator, interval=4, epoch_level=True)
    log_loss = StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True))

    fold_trainer = SupervisedTrainer(
        device=device,
        max_epochs=4,
        train_data_loader=train_loaders[index],
        network=model,
        optimizer=optimizer,
        loss_function=criterion,
        inferer=SimpleInferer(),
        amp=False,
        train_handlers=[run_validation, log_loss],
    )
    fold_trainer.run()
    return model

## Execute 5 training processes and get 5 models

In [None]:
# Train one model per fold index, in order 0 .. num-1.
models = list(map(train, range(num)))

## Define evaluation process based on `EnsembleEvaluator`

In [25]:
def ensemble_evaluate(post_transforms, models):
    """
    Run all `models` on the test loader and score the ensembled prediction.

    Each model's output is stored under its own key ("pred0", "pred1", ...).
    `post_transforms` is expected to combine those entries into a single
    "pred" entry, which is then compared with "label" by the mean Dice metric.

    Args:
        post_transforms: post-processing applied to the per-model predictions;
            it must produce the "pred" key read by the metric.
        models: sequence of networks to run; one prediction key is created
            per network.
    """
    # Derive the keys from the number of models instead of hard-coding five,
    # so the key list always has the same length as `networks`.
    pred_keys = [f"pred{i}" for i in range(len(models))]
    evaluator = EnsembleEvaluator(
        device=device,
        val_data_loader=test_loader,
        pred_keys=pred_keys,
        networks=models,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=post_transforms,
        key_val_metric={
            "test_mean_dice": MeanDice(
                include_background=True,
                output_transform=from_engine(["pred", "label"]),
            )
        },
    )
    evaluator.run()

## Evaluate the ensemble result with `MeanEnsemble`

`EnsembleEvaluator` accepts a list of models for inference and outputs a list of predictions for further operations.

Here the input data is a list or tuple of PyTorch Tensor with shape: [B, C, H, W, D].  
The list represents the output data from 5 models.  
`MeanEnsemble` also supports optional `weights` for the input data:
* The `weights` are applied to the input data starting from the highest (first) dimension; `E` below denotes the ensemble dimension, i.e. the index over the models.
* If the `weights` have only 1 dimension, they are applied to the `E` dimension of the input data.
* If the `weights` have 3 dimensions, they are applied to the `E`, `B` and `C` dimensions.

For example, to ensemble 3 segmentation model outputs where every output has 4 channels (classes),  
the input data shape can be [3, B, 4, H, W, D]. To use different `weights` for different classes,  
the `weights` shape can be [3, 1, 4], like:  
`weights = [[[1, 2, 3, 4]], [[4, 3, 2, 1]], [[1, 1, 1, 1]]]`.

In [26]:
# Post-processing for the weighted-mean ensemble: the five per-model outputs
# are first combined into a single "pred" entry, and only that combined result
# goes through sigmoid and the 0.5 threshold.
mean_post_transforms = Compose(
    [
        EnsureTyped(keys=["pred0", "pred1", "pred2", "pred3", "pred4"]),
        MeanEnsembled(
            keys=["pred0", "pred1", "pred2", "pred3", "pred4"],
            output_key="pred",
            # in this particular example, we use validation metrics as weights
            # NOTE(review): the five values below are hard-coded (one per prediction key),
            # not read from the validation runs above - confirm they still match after retraining.
            weights=[0.95, 0.94, 0.95, 0.94, 0.90],
        ),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
    ]
)
# Evaluate the mean ensemble of all trained models on the test loader.
ensemble_evaluate(mean_post_transforms, models)

INFO:ignite.engine.engine.EnsembleEvaluator:Engine run resuming from iteration 0, epoch 0 until 1 epochs
INFO:ignite.engine.engine.EnsembleEvaluator:Got new best metric of test_mean_dice: 0.9329929351806641
INFO:ignite.engine.engine.EnsembleEvaluator:Epoch[1] Complete. Time taken: 00:00:15
INFO:ignite.engine.engine.EnsembleEvaluator:Engine run complete. Time taken: 00:00:15


## Evaluate the ensemble result with `VoteEnsemble`

Here the input data is a list or tuple of PyTorch Tensor with shape: [B, C, H, W, D].  
The list represents the output data from 5 models.

Note that:
* `VoteEnsemble` expects the input data to be discrete values.
* The input data can be multi-channel data in One-Hot format or single-channel data.
* It votes to select the most common value across the items.
* The output data has the same shape as each item of the input data.

In [27]:
# One prediction key per trained model ("pred0", "pred1", ...), built once and
# shared by every step below instead of repeating the literal list four times.
vote_pred_keys = [f"pred{i}" for i in range(len(models))]
# Post-processing for the voting ensemble: every model output is activated and
# thresholded on its own, and the vote is taken over the discrete results.
vote_post_transforms = Compose(
    [
        EnsureTyped(keys=vote_pred_keys),
        Activationsd(keys=vote_pred_keys, sigmoid=True),
        # transform data into discrete before voting
        AsDiscreted(keys=vote_pred_keys, threshold=0.5),
        VoteEnsembled(keys=vote_pred_keys, output_key="pred"),
    ]
)
# Evaluate the voting ensemble of all trained models on the test loader.
ensemble_evaluate(vote_post_transforms, models)

INFO:ignite.engine.engine.EnsembleEvaluator:Engine run resuming from iteration 0, epoch 0 until 1 epochs
INFO:ignite.engine.engine.EnsembleEvaluator:Got new best metric of test_mean_dice: 0.9344435930252075
INFO:ignite.engine.engine.EnsembleEvaluator:Epoch[1] Complete. Time taken: 00:00:15
INFO:ignite.engine.engine.EnsembleEvaluator:Engine run complete. Time taken: 00:00:16


## Cleanup data directory

Remove the directory if a temporary one was used.

In [13]:
# `directory` is None when no MONAI_DATA_DIRECTORY was specified, in which case
# `root_dir` is a temporary directory created by tempfile.mkdtemp() and is
# removed here together with everything written into it. A user-supplied
# directory is left untouched.
if directory is None:
    shutil.rmtree(root_dir)