# Models ensemble to achieve better test metrics
Model ensembling is a popular strategy in machine learning and deep learning for achieving more accurate and more stable outputs.  
A typical practice is:
- Split the whole training dataset into K folds.
- Train K models, each on a different combination of K-1 folds.
- Run inference on the test data with all K models.
- Compute the weighted average of the K predictions, or take a vote for the most common value, as the final result.
<p>
<img src="./images/models_ensemble.png" width="80%" alt='models_ensemble'>
</p>

MONAI provides the `EnsembleEvaluator` engine and the `MeanEnsemble` and `VoteEnsemble` post transforms.  
This tutorial shows how to use the ensemble modules in MONAI to set up an ensemble program.

In [1]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shutil
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch

import monai
from monai.data import create_test_image_3d
from monai.engines import EnsembleEvaluator, SupervisedEvaluator, SupervisedTrainer
from monai.handlers import MeanDice, StatsHandler, ValidationHandler
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.transforms import (
    Activationsd,
    AsChannelFirstd,
    AsDiscreted,
    Compose,
    LoadNiftid,
    MeanEnsembled,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
    ToTensord,
    VoteEnsembled,
)
from monai.utils import set_determinism

## Set deterministic training for reproducibility

In [None]:
# Fix the random seed up front so that repeated runs of this tutorial are reproducible.
set_determinism(seed=0)

## Generate random (image, label) pairs
Generate 60 pairs for the task: 50 for training and 10 for testing.  
Then split the 50 training pairs into 5 folds to train 5 separate models.

In [2]:
# The synthetic volumes are written to (and read back from) "./runs"; create the
# directory first so the saves below cannot fail on a missing path.
os.makedirs("./runs", exist_ok=True)

# Write 60 synthetic (image, segmentation) volume pairs as NIfTI files.
for i in range(60):
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join("./runs", f"img{i:d}.nii.gz"))
    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join("./runs", f"seg{i:d}.nii.gz"))

# sorted() orders the names lexicographically (img0, img1, img10, ...), but it
# does so identically for both lists, so images[k] and segs[k] stay paired.
images = sorted(glob(os.path.join("./runs", "img*.nii.gz")))
segs = sorted(glob(os.path.join("./runs", "seg*.nii.gz")))

# 5-fold split of the first 50 pairs: fold i (10 pairs) is held out for
# validation and the remaining 40 pairs form the training set of model i.
# The lists start empty, so entries must be appended rather than assigned by index.
train_files = []
val_files = []
for i in range(5):
    fold_start = 10 * i
    fold_end = 10 * (i + 1)
    train_files.append(
        [
            {"image": img, "label": seg}
            for img, seg in zip(
                images[:fold_start] + images[fold_end:50],
                segs[:fold_start] + segs[fold_end:50],
            )
        ]
    )
    val_files.append(
        [
            {"image": img, "label": seg}
            for img, seg in zip(images[fold_start:fold_end], segs[fold_start:fold_end])
        ]
    )

# The last 10 pairs are never seen during training and serve as the test set.
test_files = [{"image": img, "label": seg} for img, seg in zip(images[50:60], segs[50:60])]

## Setup transforms for training and validation

In [4]:
# Training pipeline: load each NIfTI pair, move the trailing channel axis
# (channel_dim=-1) to the front, rescale intensities, then apply the random
# crop / rotation augmentations before converting to tensors.
train_transforms = Compose([
    LoadNiftid(keys=["image", "label"]),
    AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
    ScaleIntensityd(keys=["image", "label"]),
    # NOTE(review): presumably draws num_samples crops of spatial_size per volume,
    # balanced between foreground and background by pos/neg -- confirm in MONAI docs.
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[96, 96, 96],
        pos=1,
        neg=1,
        num_samples=4,
    ),
    RandRotate90d(
        keys=["image", "label"],
        prob=0.5,
        spatial_axes=[0, 2],
    ),
    ToTensord(keys=["image", "label"]),
])

# Validation / test pipeline: the same deterministic steps as above, with the
# random augmentations left out.
val_transforms = Compose([
    LoadNiftid(keys=["image", "label"]),
    AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
    ScaleIntensityd(keys=["image", "label"]),
    ToTensord(keys=["image", "label"]),
])

## Define CacheDatasets and DataLoaders for training and validation

In [5]:
# One training dataset / loader per fold; position k in each list belongs to model k.
train_dss = [monai.data.CacheDataset(data=files, transform=train_transforms) for files in train_files]
train_loaders = [monai.data.DataLoader(ds, batch_size=2, shuffle=True, num_workers=4) for ds in train_dss]

# One validation dataset / loader per fold, built from that fold's held-out
# files only (not from the whole list of folds).
val_dss = [monai.data.CacheDataset(data=files, transform=val_transforms) for files in val_files]
val_loaders = [monai.data.DataLoader(ds, batch_size=1, num_workers=4) for ds in val_dss]

# A single test loader shared by every model and by the ensemble evaluation.
test_ds = monai.data.CacheDataset(data=test_files, transform=val_transforms)
test_loader = monai.data.DataLoader(test_ds, batch_size=1, num_workers=4)



## Define a typical PyTorch training process

In [None]:
def train(index):
    """Train one UNet on fold ``index`` and return the trained network.

    The network is trained for 2 epochs on ``train_loaders[index]`` and is
    validated on ``val_loaders[index]`` every 2 epochs through a
    ``SupervisedEvaluator`` attached as a ``ValidationHandler``.

    Args:
        index: position of the fold in ``train_loaders`` / ``val_loaders``.

    Returns:
        The trained ``monai.networks.nets.UNet``, located on ``cuda:0``.
    """
    device = torch.device("cuda:0")
    net = monai.networks.nets.UNet(dimensions=3, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256),
                                   strides=(2, 2, 2, 2), num_res_units=2).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    # Turn the raw network output into a binary mask before computing the metric.
    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True)
        ]
    )
    # output_transform returns None for every iteration.
    # NOTE(review): presumably this suppresses per-iteration printing -- confirm in StatsHandler docs.
    val_handlers = [StatsHandler(output_transform=lambda x: None)]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loaders[index],
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers
    )
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"])
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=2,
        # The per-fold loaders live in the ``train_loaders`` list.
        train_data_loader=train_loaders[index],
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        amp=False,
        train_handlers=train_handlers
    )
    trainer.run()
    # The return must be inside the function body so callers receive the model.
    return net

## Execute 5 training processes and get 5 models

In [None]:
# Train the folds one after another; models[k] is the network trained on fold k.
models = list(map(train, range(5)))

## Separately evaluate the 5 models on the test data

In [None]:
def evaluate(model):
    """Evaluate a single trained model on the shared test loader.

    Runs sliding-window inference over ``test_loader``, applies sigmoid and
    thresholding to the predictions, and computes the mean Dice metric under
    the name ``test_mean_dice``.

    Args:
        model: a trained network (a ``torch.nn.Module``) to evaluate.
    """
    # No module-level ``device`` exists, so take it from the model's own
    # parameters; the evaluator then runs on the device holding the weights.
    device = next(model.parameters()).device

    test_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True)
        ]
    )
    # output_transform returns None for every iteration.
    # NOTE(review): presumably this suppresses per-iteration printing -- confirm in StatsHandler docs.
    test_handlers = [StatsHandler(output_transform=lambda x: None)]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=test_loader,
        network=model,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=test_post_transforms,
        key_val_metric={
            "test_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=test_handlers
    )
    evaluator.run()

# Evaluate each trained model on its own, one after another, so the individual
# test metrics are available next to the ensemble results computed afterwards.
for model in models:
    evaluate(model)

## Evaluate the ensemble result with `MeanEnsemble`

In [None]:
def ensemble_evaluate(post_transforms, models):
    """Evaluate an ensemble of models on the shared test loader.

    Every network in ``models`` runs sliding-window inference on each test
    batch; the output of network ``k`` is stored under the key ``"pred{k}"``.
    ``post_transforms`` must combine those keys into a single ``"pred"`` entry,
    which is what the mean Dice metric (``test_mean_dice``) reads.

    Args:
        post_transforms: transform chain that merges ``pred0..predN-1`` into ``pred``.
        models: sequence of trained networks (``torch.nn.Module``) forming the ensemble.
    """
    # No module-level ``device`` exists, so take it from the first model's
    # parameters (all models are assumed to live on the same device).
    device = next(models[0].parameters()).device

    # output_transform returns None for every iteration.
    # NOTE(review): presumably this suppresses per-iteration printing -- confirm in StatsHandler docs.
    test_handlers = [StatsHandler(output_transform=lambda x: None)]

    evaluator = EnsembleEvaluator(
        device=device,
        val_data_loader=test_loader,
        # One prediction key per network: "pred0", "pred1", ... (five keys for five models).
        pred_keys=[f"pred{i}" for i in range(len(models))],
        networks=models,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=post_transforms,
        key_val_metric={
            "test_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=test_handlers
    )
    evaluator.run()

# Mean ensemble: average the five raw predictions into "pred" first, then apply
# sigmoid and thresholding to the averaged result.
mean_post_transforms = Compose(
    [
        # Dictionary-based transform (takes keys / output_key), as imported at the top.
        MeanEnsembled(
            keys=["pred0", "pred1", "pred2", "pred3", "pred4"],
            output_key="pred",
            # Equal weights, one per model; replace with per-model values
            # (e.g. validation scores) to favour stronger models.
            weights=[1.0, 1.0, 1.0, 1.0, 1.0]
        ),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True)
    ]
)
ensemble_evaluate(mean_post_transforms, models)

## Evaluate the ensemble result with `VoteEnsemble`

In [None]:
# Vote ensemble: each model's prediction is activated and thresholded to a
# discrete mask first, because voting operates on discrete values; the five
# masks are then merged into "pred" by majority vote.
vote_post_transforms = Compose(
    [
        Activationsd(keys=["pred0", "pred1", "pred2", "pred3", "pred4"], sigmoid=True),
        AsDiscreted(keys=["pred0", "pred1", "pred2", "pred3", "pred4"], threshold_values=True),
        # Dictionary-based transform (takes keys / output_key), as imported at the top.
        # No weights are passed here: voting counts every model equally.
        VoteEnsembled(
            keys=["pred0", "pred1", "pred2", "pred3", "pred4"],
            output_key="pred"
        )
    ]
)
ensemble_evaluate(vote_post_transforms, models)