# Model ensembling to achieve better test metrics

Model ensembling is a popular strategy in machine learning and deep learning for producing more accurate and more stable outputs.  
A typical practice is:
* Split the training dataset into K folds.
* Train K models, each on a different combination of K-1 folds, holding out the remaining fold for validation.
* Run inference on the test data with all K models.
* Combine the K predictions into the final result, either by (optionally weighted) averaging or by voting for the most common value.
<p>
<img src="../figures/models_ensemble.png" width="80%" alt='models_ensemble'>
</p>
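
The fold-splitting step above can be sketched in plain Python. This is only an illustration: the helper name `k_fold_splits` and the use of contiguous folds are assumptions made for the example, not part of MONAI.

```python
def k_fold_splits(items, k):
    """Return k (train, val) pairs; fold i is held out for validation in pair i."""
    if k < 2 or k > len(items):
        raise ValueError("k must be between 2 and len(items)")
    # Contiguous folds; the first (len(items) % k) folds get one extra item.
    base, extra = divmod(len(items), k)
    folds, start = [], 0
    for i in range(k):
        size = base + (1 if i < extra else 0)
        folds.append(items[start:start + size])
        start += size
    # Pair every held-out fold with the concatenation of the other k-1 folds.
    splits = []
    for i in range(k):
        train = [x for j, fold in enumerate(folds) if j != i for x in fold]
        splits.append((train, folds[i]))
    return splits


# 50 training items and 5 folds, the sizes used later in this tutorial.
splits = k_fold_splits(list(range(50)), k=5)
for train, val in splits:
    print(len(train), len(val))
```

Each of the K models sees a different 80% of the training data here, which is what makes their errors partly independent and the combined prediction more stable.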

MONAI provides the `EnsembleEvaluator` engine and the `MeanEnsemble` and `VoteEnsemble` post transforms.  
This tutorial shows how to use these modules to set up an ensemble workflow.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/models_ensemble.ipynb)

## Setup environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[ignite, nibabel, tqdm]"

## Setup imports

In [2]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import logging
import os
import tempfile
import shutil
import sys

import nibabel as nib
import numpy as np
import torch

from monai.config import print_config
from monai.data import CacheDataset, DataLoader, create_test_image_3d
from monai.engines import (
    EnsembleEvaluator,
    SupervisedEvaluator,
    SupervisedTrainer
)
from monai.handlers import MeanDice, StatsHandler, ValidationHandler, from_engine
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
    Activationsd,
    AsChannelFirstd,
    AsDiscreted,
    Compose,
    LoadImaged,
    MeanEnsembled,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
    EnsureTyped,
    VoteEnsembled,
)
from monai.utils import set_determinism

print_config()

MONAI version: 0.6.0rc1+23.gc6793fd0
Numpy version: 1.20.3
Pytorch version: 1.9.0a0+c3d40fd
MONAI flags: HAS_EXT = True, USE_COMPILED = False
MONAI rev id: c6793fd0f316a448778d0047664aaf8c1895fe1c

Optional dependencies:
Pytorch Ignite version: 0.4.5
Nibabel version: 3.2.1
scikit-image version: 0.15.0
Pillow version: 7.0.0
Tensorboard version: 2.5.0
gdown version: 3.13.0
TorchVision version: 0.10.0a0
ITK version: 5.1.2
tqdm version: 4.53.0
lmdb version: 1.2.1
psutil version: 5.8.0
pandas version: 1.1.4
einops version: 0.3.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [3]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/workspace/data/medical


## Set determinism, logging, device

In [4]:
set_determinism(seed=0)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
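# fall back to CPU when no GPU is available (much slower, but still runnable)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")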

## Generate random (image, label) pairs

Generate 60 pairs for the task: 50 for training and 10 for testing.  
Then split the 50 training pairs into 5 folds to train 5 separate models, each validated on a different fold.

In [5]:
data_dir = os.path.join(root_dir, "runs")

if not os.path.exists(data_dir):
    os.makedirs(data_dir)
    for i in range(60):
        im, seg = create_test_image_3d(
            128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(data_dir, f"img{i}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(data_dir, f"seg{i}.nii.gz"))

images = sorted(glob.glob(os.path.join(data_dir, "img*.nii.gz")))
segs = sorted(glob.glob(os.path.join(data_dir, "seg*.nii.gz")))

train_files = []
val_files = []
for i in range(5):
    train_files.append(
        [
            {"image": img, "label": seg}
            for img, seg in zip(
                images[: (10 * i)] + images[(10 * (i + 1)): 50],
                segs[: (10 * i)] + segs[(10 * (i + 1)): 50],
            )
        ]
    )
    val_files.append(
        [
            {"image": img, "label": seg}
            for img, seg in zip(images[(10 * i): (10 * (i + 1))],
                                segs[(10 * i): (10 * (i + 1))])
        ]
    )

test_files = [{"image": img, "label": seg}
              for img, seg in zip(images[50:60], segs[50:60])]

## Setup transforms for training and validation

In [6]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=[96, 96, 96],
            pos=1,
            neg=1,
            num_samples=4,
        ),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["image", "label"]),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
    ]
)

## Define CacheDatasets and DataLoaders for train, validation and test

In [7]:
num_models = 5
train_dss = [CacheDataset(
    data=train_files[i],
    transform=train_transforms) for i in range(num_models)]
train_loaders = [
    DataLoader(
        train_dss[i], batch_size=2, shuffle=True, num_workers=4)
    for i in range(num_models)
]

val_dss = [CacheDataset(data=val_files[i], transform=val_transforms)
           for i in range(num_models)]
val_loaders = [DataLoader(val_dss[i], batch_size=1, num_workers=4)
               for i in range(num_models)]

test_ds = CacheDataset(data=test_files, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4)

100%|██████████| 40/40 [00:01<00:00, 26.37it/s]
100%|██████████| 40/40 [00:01<00:00, 33.42it/s]
100%|██████████| 40/40 [00:01<00:00, 36.70it/s]
100%|██████████| 40/40 [00:00<00:00, 40.63it/s]
100%|██████████| 40/40 [00:00<00:00, 43.25it/s]
100%|██████████| 10/10 [00:00<00:00, 40.24it/s]
100%|██████████| 10/10 [00:00<00:00, 37.47it/s]
100%|██████████| 10/10 [00:00<00:00, 39.96it/s]
100%|██████████| 10/10 [00:00<00:00, 38.21it/s]
100%|██████████| 10/10 [00:00<00:00, 39.50it/s]
100%|██████████| 10/10 [00:00<00:00, 42.86it/s]


## Define a training process based on workflows

More usage examples of MONAI workflows are available at: [workflow examples](https://github.com/Project-MONAI/tutorials/tree/master/modules/engines).

In [8]:
def train(index):
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    val_post_transforms = Compose(
        [EnsureTyped(keys="pred"), Activationsd(keys="pred", sigmoid=True), AsDiscreted(
            keys="pred", threshold_values=True)]
    )

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loaders[index],
        network=net,
        inferer=SlidingWindowInferer(
            roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=True,
                output_transform=from_engine(["pred", "label"]),
            )
        },
    )
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=4, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=4,
        train_data_loader=train_loaders[index],
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        amp=False,
        train_handlers=train_handlers,
    )
    trainer.run()
    return net

## Execute 5 training processes and get 5 models

In [9]:
models = [train(i) for i in range(num_models)]

INFO:ignite.engine.engine.SupervisedTrainer:Engine run resuming from iteration 0, epoch 0 until 4 epochs
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 1/20 -- train_loss: 0.6230 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 2/20 -- train_loss: 0.5654 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 3/20 -- train_loss: 0.5949 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 4/20 -- train_loss: 0.5036 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 5/20 -- train_loss: 0.4908 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 6/20 -- train_loss: 0.4712 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 7/20 -- train_loss: 0.4696 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 8/20 -- train_loss: 0.5312 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 9/20 -- train_loss: 0.4865 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 10/20 -- train_loss: 0.

INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 1/20 -- train_loss: 0.5874 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 2/20 -- train_loss: 0.5677 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 3/20 -- train_loss: 0.5438 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 4/20 -- train_loss: 0.5830 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 5/20 -- train_loss: 0.5569 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 6/20 -- train_loss: 0.5162 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 7/20 -- train_loss: 0.5138 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 8/20 -- train_loss: 0.4849 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 9/20 -- train_loss: 0.5582 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 10/20 -- train_loss: 0.4814 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 11/20 -- train_loss: 0.4651 
INFO:ign

INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 2/20 -- train_loss: 0.5730 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 3/20 -- train_loss: 0.5504 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 4/20 -- train_loss: 0.5812 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 5/20 -- train_loss: 0.4541 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 6/20 -- train_loss: 0.4813 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 7/20 -- train_loss: 0.4551 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 8/20 -- train_loss: 0.4609 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 9/20 -- train_loss: 0.5623 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 10/20 -- train_loss: 0.4593 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 11/20 -- train_loss: 0.5020 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 12/20 -- train_loss: 0.4532 
INFO:ig

INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 3/20 -- train_loss: 0.6096 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 4/20 -- train_loss: 0.5008 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 5/20 -- train_loss: 0.5146 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 6/20 -- train_loss: 0.4756 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 7/20 -- train_loss: 0.3944 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 8/20 -- train_loss: 0.4824 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 9/20 -- train_loss: 0.5063 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 10/20 -- train_loss: 0.3977 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 11/20 -- train_loss: 0.4890 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 12/20 -- train_loss: 0.4922 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 13/20 -- train_loss: 0.4344 
INFO:i

INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 4/20 -- train_loss: 0.4934 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 5/20 -- train_loss: 0.5442 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 6/20 -- train_loss: 0.5078 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 7/20 -- train_loss: 0.5374 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 8/20 -- train_loss: 0.5494 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 9/20 -- train_loss: 0.4872 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 10/20 -- train_loss: 0.5015 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 11/20 -- train_loss: 0.5551 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 12/20 -- train_loss: 0.5096 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 13/20 -- train_loss: 0.5215 
INFO:ignite.engine.engine.SupervisedTrainer:Epoch: 1/4, Iter: 14/20 -- train_loss: 0.5300 
INFO:

## Define evaluation process based on `EnsembleEvaluator`

In [10]:
def ensemble_evaluate(post_transforms, models):
    evaluator = EnsembleEvaluator(
        device=device,
        val_data_loader=test_loader,
        pred_keys=["pred0", "pred1", "pred2", "pred3", "pred4"],
        networks=models,
        inferer=SlidingWindowInferer(
            roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=post_transforms,
        key_val_metric={
            "test_mean_dice": MeanDice(
                include_background=True,
                output_transform=from_engine(["pred", "label"]),
            )
        },
    )
    evaluator.run()

## Evaluate the ensemble result with `MeanEnsemble`

`EnsembleEvaluator` accepts a list of models for inference and outputs a list of predictions for further operations.
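
Conceptually, this amounts to running every model on the same input and storing each output under its own key. The sketch below illustrates only that idea; `run_ensemble`, `input_key` and the toy callables are hypothetical names for this example and do not reflect the actual `EnsembleEvaluator` implementation.

```python
def run_ensemble(networks, pred_keys, batch, input_key="image"):
    """Run every network on the same input and store each output under its own key."""
    if len(networks) != len(pred_keys):
        raise ValueError("need exactly one prediction key per network")
    output = dict(batch)  # shallow copy so the caller's batch is left untouched
    for key, network in zip(pred_keys, networks):
        output[key] = network(batch[input_key])
    return output


# Stand-in "models": plain callables instead of trained networks.
toy_networks = [lambda x, s=s: [v * s for v in x] for s in (1, 2, 3)]
toy_keys = [f"pred{i}" for i in range(len(toy_networks))]
result = run_ensemble(toy_networks, toy_keys, {"image": [1.0, 2.0], "label": [0, 1]})
print(sorted(result))
```

The post transforms then read these per-model keys and merge them into a single `pred` entry, which is the key the metric consumes.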

Here the input data is a list or tuple of PyTorch Tensors, each with shape [B, C, H, W, D].  
The list holds the outputs of the 5 models and is stacked along a new leading ensemble dimension `E`.  
`MeanEnsemble` also supports optional `weights` for the input data:
* The `weights` are applied to the input data starting from the highest (leading) dimension.
* If the `weights` have only 1 dimension, they are applied to the `E` dimension of the input data.
* If the `weights` have 3 dimensions, they are applied to the `E`, `B` and `C` dimensions.

For example, to ensemble 3 segmentation model outputs where every output has 4 channels (classes), the input data shape can be [3, B, 4, H, W, D].  
To weight the classes differently, the `weights` shape can be [3, 1, 4], like:  
`weights = [[[1, 2, 3, 4]], [[4, 3, 2, 1]], [[1, 1, 1, 1]]]`.
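
To make the broadcasting rule concrete, here is a minimal NumPy sketch of a weighted mean over the ensemble dimension. It assumes the weights act as a normalized weighted average (each weight divided by the sum of weights along `E`); the function name `weighted_mean_ensemble` is hypothetical and this is not MONAI's own code.

```python
import numpy as np


def weighted_mean_ensemble(preds, weights=None):
    """Average same-shaped arrays, optionally weighting along the leading dims."""
    stacked = np.stack(preds, axis=0).astype(np.float64)  # shape [E, ...]
    if weights is None:
        return stacked.mean(axis=0)
    w = np.asarray(weights, dtype=np.float64)
    if w.ndim == 0 or w.ndim > stacked.ndim or w.shape[0] != stacked.shape[0]:
        raise ValueError("weights must start with the ensemble dimension E")
    # Pad trailing singleton axes so [E] or [E, 1, C] weights broadcast
    # against data of shape [E, B, C, H, W, D].
    w = w.reshape(w.shape + (1,) * (stacked.ndim - w.ndim))
    # Normalized weighted average over the ensemble axis (assumption, see above).
    return (stacked * w).sum(axis=0) / w.sum(axis=0)


# Three toy "model outputs" of shape [B=1, C=4, H=2, W=2, D=2] with values 1, 2, 3.
preds = [np.full((1, 4, 2, 2, 2), v) for v in (1.0, 2.0, 3.0)]
class_weights = [[[1, 2, 3, 4]], [[4, 3, 2, 1]], [[1, 1, 1, 1]]]  # shape [3, 1, 4]
fused = weighted_mean_ensemble(preds, class_weights)
print(fused.shape)
print(fused[0, :, 0, 0, 0])  # one fused value per class
```

With 1-dimensional weights such as the per-model validation scores used in this tutorial, every voxel of a given model's output is scaled by the same factor.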

In [11]:
mean_post_transforms = Compose(
    [
        EnsureTyped(keys=["pred0", "pred1", "pred2", "pred3", "pred4"]),
        MeanEnsembled(
            keys=["pred0", "pred1", "pred2", "pred3", "pred4"],
            output_key="pred",
            # in this particular example, we use validation metrics as weights
            weights=[0.95, 0.94, 0.95, 0.94, 0.90],
        ),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ]
)
ensemble_evaluate(mean_post_transforms, models)

INFO:ignite.engine.engine.EnsembleEvaluator:Engine run resuming from iteration 0, epoch 0 until 1 epochs
INFO:ignite.engine.engine.EnsembleEvaluator:Got new best metric of test_mean_dice: 0.9435271978378296
INFO:ignite.engine.engine.EnsembleEvaluator:Epoch[1] Complete. Time taken: 00:00:02
INFO:ignite.engine.engine.EnsembleEvaluator:Engine run complete. Time taken: 00:00:03


## Evaluate the ensemble result with `VoteEnsemble`

Here the input data is a list or tuple of PyTorch Tensors, each with shape [B, C, H, W, D].  
The list holds the outputs of the 5 models.

Note that:
* `VoteEnsemble` expects the input data to be discrete values.
* The input data can be multi-channel data in One-Hot format or single-channel data.
* It votes to select the most common value across the items.
* The output data has the same shape as each item of the input data.
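
For the single-channel binary case used in this tutorial, the voting rule can be sketched in NumPy as a voxel-wise majority vote. The function name `majority_vote` and the tie handling (ties go to background) are assumptions made for this illustration; the library's behavior on ties may differ.

```python
import numpy as np


def majority_vote(masks):
    """Voxel-wise majority vote over a list of same-shaped binary (0/1) masks."""
    stacked = np.stack(masks, axis=0)  # shape [E, ...]
    if not np.isin(stacked, (0, 1)).all():
        raise ValueError("expected discrete 0/1 masks")
    # A voxel is foreground when more than half of the models marked it.
    return (stacked.mean(axis=0) > 0.5).astype(stacked.dtype)


# Three toy binary masks; each column is one "voxel".
a = np.array([[1, 1, 0, 0]])
b = np.array([[1, 0, 1, 0]])
c = np.array([[1, 1, 0, 1]])
voted = majority_vote([a, b, c])
print(voted)
```

Using an odd number of models, as with the 5 models here, avoids ties altogether for binary masks.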

In [12]:
vote_post_transforms = Compose(
    [
        EnsureTyped(keys=["pred0", "pred1", "pred2", "pred3", "pred4"]),
        Activationsd(keys=["pred0", "pred1", "pred2",
                           "pred3", "pred4"], sigmoid=True),
        # transform data into discrete before voting
        AsDiscreted(keys=["pred0", "pred1", "pred2", "pred3",
                          "pred4"], threshold_values=True),
        VoteEnsembled(keys=["pred0", "pred1", "pred2",
                            "pred3", "pred4"], output_key="pred"),
    ]
)
ensemble_evaluate(vote_post_transforms, models)

INFO:ignite.engine.engine.EnsembleEvaluator:Engine run resuming from iteration 0, epoch 0 until 1 epochs
INFO:ignite.engine.engine.EnsembleEvaluator:Got new best metric of test_mean_dice: 0.9436934590339661
INFO:ignite.engine.engine.EnsembleEvaluator:Epoch[1] Complete. Time taken: 00:00:02
INFO:ignite.engine.engine.EnsembleEvaluator:Engine run complete. Time taken: 00:00:03


## Cleanup data directory

Remove the directory if a temporary one was used.

In [13]:
if directory is None:
    shutil.rmtree(root_dir)