Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License. 

# Customize Algorithm Ensemble in Auto3DSeg

This notebook gives a brief example of how to customize the Auto3DSeg ensemble pipeline by defining a new ensemble class.

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, yaml, tqdm]"

## Setup imports

In [None]:
import os
import json
import numpy as np
import nibabel as nib
import random

from copy import deepcopy
from pathlib import Path

from monai.apps.auto3dseg import (
    AlgoEnsemble,
    AlgoEnsembleBuilder,
    DataAnalyzer,
    BundleGen,
)

from monai.bundle.config_parser import ConfigParser
from monai.config import print_config
from monai.data import create_test_image_3d
from monai.utils.enums import AlgoEnsembleKeys

print_config()

## Simulate a special dataset

Training deep learning models usually takes a long time. To provide a "Hello World!" experience of Auto3DSeg in this notebook, we simulate a small dataset and train for only a few epochs. With so little training, the resulting models should not be expected to perform well, but the entire pipeline completes within minutes.

`sim_datalist` describes the simulated dataset. It lists 12 training and 2 testing image/label pairs.
The training data are split into 3 folds of 4 cases each, so each fold uses 8 images for training and 4 images for validation.
The size of each simulated volume is defined by `sim_dim`.

In [None]:
sim_datalist = {
    "testing": [
        {"image": "test_image_001.nii.gz", "label": "test_label_001.nii.gz"},
        {"image": "test_image_002.nii.gz", "label": "test_label_002.nii.gz"},
    ],
    "training": [
        {"fold": 0, "image": "tr_image_001.nii.gz", "label": "tr_label_001.nii.gz"},
        {"fold": 0, "image": "tr_image_002.nii.gz", "label": "tr_label_002.nii.gz"},
        {"fold": 0, "image": "tr_image_003.nii.gz", "label": "tr_label_003.nii.gz"},
        {"fold": 0, "image": "tr_image_004.nii.gz", "label": "tr_label_004.nii.gz"},
        {"fold": 1, "image": "tr_image_005.nii.gz", "label": "tr_label_005.nii.gz"},
        {"fold": 1, "image": "tr_image_006.nii.gz", "label": "tr_label_006.nii.gz"},
        {"fold": 1, "image": "tr_image_007.nii.gz", "label": "tr_label_007.nii.gz"},
        {"fold": 1, "image": "tr_image_008.nii.gz", "label": "tr_label_008.nii.gz"},
        {"fold": 2, "image": "tr_image_009.nii.gz", "label": "tr_label_009.nii.gz"},
        {"fold": 2, "image": "tr_image_010.nii.gz", "label": "tr_label_010.nii.gz"},
        {"fold": 2, "image": "tr_image_011.nii.gz", "label": "tr_label_011.nii.gz"},
        {"fold": 2, "image": "tr_image_012.nii.gz", "label": "tr_label_012.nii.gz"},
    ],
}

sim_dim = (64, 64, 64)
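
The fold arithmetic described above can be checked with a small sketch. It uses a stand-in list with the same `"fold"` layout as `sim_datalist["training"]`; the helper `split_by_fold` is hypothetical and only illustrates the idea that fold `k` validates on the cases tagged `k` and trains on the rest.

```python
# Stand-in for the "training" list above: 12 cases, 4 per fold (folds 0, 1, 2).
training = [{"fold": i // 4, "image": f"tr_image_{i + 1:03d}.nii.gz"} for i in range(12)]


def split_by_fold(items, fold):
    """Return (train, val) lists for one cross-validation fold (hypothetical helper)."""
    val = [d for d in items if d["fold"] == fold]
    train = [d for d in items if d["fold"] != fold]
    return train, val


# Map each fold to (number of training cases, number of validation cases).
fold_sizes = {k: tuple(len(part) for part in split_by_fold(training, k)) for k in range(3)}
print(fold_sizes)  # {0: (8, 4), 1: (8, 4), 2: (8, 4)}
```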

## Generate images and labels

Now we can use MONAI's `create_test_image_3d` function and nibabel's `nib.Nifti1Image` class to generate the simulated 3D images and labels under `work_dir`.

In [None]:
work_dir = str(Path("./ensemble_byoc_work_dir"))
os.makedirs(work_dir, exist_ok=True)

dataroot_dir = os.path.join(work_dir, "sim_dataroot")
os.makedirs(dataroot_dir, exist_ok=True)

datalist_file = os.path.join(work_dir, "sim_datalist.json")
with open(datalist_file, "w") as f:
    json.dump(sim_datalist, f)

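# Create the random state once so that each case gets a different (but reproducible) volume.
rng = np.random.RandomState(42)
for d in sim_datalist["testing"] + sim_datalist["training"]:
    im, seg = create_test_image_3d(
        sim_dim[0], sim_dim[1], sim_dim[2], rad_max=10, num_seg_classes=1, random_state=rng
    )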
    image_fpath = os.path.join(dataroot_dir, d["image"])
    label_fpath = os.path.join(dataroot_dir, d["label"])
    nib.save(nib.Nifti1Image(im, affine=np.eye(4)), image_fpath)
    nib.save(nib.Nifti1Image(seg, affine=np.eye(4)), label_fpath)
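
Before running the rest of the pipeline, it can help to confirm that every file named in the datalist was actually written. The sketch below is one possible check using only the standard library. `missing_files` is a hypothetical helper, and the demo datalist and file names are made up for illustration; in the notebook it would be called as `missing_files(sim_datalist, dataroot_dir)`.

```python
import os
import tempfile


def missing_files(datalist, dataroot):
    """Return the relative paths in the datalist that do not exist under dataroot."""
    missing = []
    for section in ("training", "testing"):
        for entry in datalist.get(section, []):
            for key in ("image", "label"):
                if key in entry and not os.path.isfile(os.path.join(dataroot, entry[key])):
                    missing.append(entry[key])
    return missing


# Tiny demonstration with a made-up datalist and a temporary directory.
demo_list = {"training": [{"fold": 0, "image": "a.nii.gz", "label": "b.nii.gz"}], "testing": []}
with tempfile.TemporaryDirectory() as tmp:
    # Create only the image file, so the label is reported as missing.
    open(os.path.join(tmp, "a.nii.gz"), "wb").close()
    demo_missing = missing_files(demo_list, tmp)
print(demo_missing)  # ['b.nii.gz']
```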

## Define a new class that inherits from AlgoEnsemble

In [None]:
class MyAlgoEnsemble(AlgoEnsemble):
    """
    Ensemble that randomly selects `n_models` of the trained models.
    """

    def __init__(self, n_models=3):

        super().__init__()
        self.n_models = n_models

    def collect_algos(self):
        """
        Select the algos to ensemble from the `self.algos` list and store copies of them in `self.algo_ensemble`.
        """
        n = len(self.algos)
        if self.n_models > n:
            raise ValueError(f"Number of loaded Algo is {n}, but {self.n_models} algos are requested.")

        indexes = list(range(n))
        random.shuffle(indexes)
        indexes = indexes[0 : self.n_models]
        self.algo_ensemble = []
        for idx in indexes:
            self.algo_ensemble.append(deepcopy(self.algos[idx]))
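
The selection step in `collect_algos` amounts to sampling indexes without replacement. The stand-alone sketch below isolates that logic so it can be tried without any trained models. `pick_indexes` and its `seed` argument are hypothetical and not part of MONAI; the seed is added only as one way to make the random pick repeatable.

```python
import random


def pick_indexes(n_available, n_models, seed=None):
    """Pick n_models distinct indexes from range(n_available) (hypothetical helper)."""
    if n_models > n_available:
        raise ValueError(f"Number of loaded Algo is {n_available}, but {n_models} algos are requested.")
    # A private Random instance keeps the pick independent of the global random state.
    rng = random.Random(seed)
    return rng.sample(range(n_available), n_models)


# Same seed, same pick: useful when the ensemble result needs to be reproduced.
first = pick_indexes(10, 3, seed=0)
second = pick_indexes(10, 3, seed=0)
print(first, first == second)
```

`random.sample` draws without replacement, so it selects the same kind of subset as the shuffle-and-slice approach in the class above.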

## Run Auto3DSeg data analyzer, algo generation, and training

> NOTE: For demo purposes, the snippet below overrides all algorithms with the same training parameters, `train_param`, so that each one trains for only a few epochs.
> To use more than one GPU, change the value of `CUDA_VISIBLE_DEVICES`, or remove the key to use all available devices.
> The number of GPUs must not exceed the number of partitions the training dataset can be split into.

In [None]:
da = DataAnalyzer(datalist_file, dataroot_dir)
da.get_all_case_stats()

# Named `input_dict` to avoid shadowing the Python builtin `input`.
input_dict = {
    "modality": "MRI",
    "datalist": datalist_file,
    "dataroot": dataroot_dir,
}

input_cfg = "input.yaml"
ConfigParser.export_config_file(input_dict, input_cfg)

bundle_generator = BundleGen(algo_path=work_dir, data_stats_filename="data_stats.yaml", data_src_cfg_name=input_cfg)
bundle_generator.generate(work_dir, num_fold=2)
history = bundle_generator.get_history()

max_epochs = 2

train_param = {
    "CUDA_VISIBLE_DEVICES": [0],  # use only one GPU
    "num_epochs_per_validation": 1,
    "num_images_per_batch": 2,
    "num_epochs": max_epochs,
    "num_warmup_epochs": 1,
}

for h in history:
    for _, algo in h.items():
        algo.train(train_param)
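
The note above asks that the number of GPUs not exceed the number of partitions of the training data. One way to read this constraint, stated here as an assumption, is that every device should receive at least one training case. The sketch below trims a requested device list on that basis. `usable_gpu_ids` is a hypothetical helper and is not part of Auto3DSeg.

```python
def usable_gpu_ids(requested_ids, n_train_cases):
    """Trim a list of GPU ids so that each device can receive at least one training case.

    Hypothetical helper for illustration; it is not part of MONAI.
    """
    if n_train_cases < 1:
        raise ValueError("At least one training case is required.")
    return list(requested_ids)[:n_train_cases]


# With 8 training cases per fold, as in the simulated dataset, a request for 2 GPUs is kept as-is.
gpu_ids = usable_gpu_ids([0, 1], n_train_cases=8)
# A request for more devices than cases is trimmed.
trimmed = usable_gpu_ids([0, 1, 2, 3], n_train_cases=2)
print(gpu_ids, trimmed)  # [0, 1] [0, 1]
```

The result could then be assigned to the `CUDA_VISIBLE_DEVICES` key of the training parameters before calling `train`.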

## Apply MyAlgoEnsemble

As defined earlier in this notebook, `MyAlgoEnsemble` randomly shuffles the trained models and picks three of them (the default `n_models=3`) for the ensemble.

In [None]:
builder = AlgoEnsembleBuilder(history, input_cfg)
builder.set_ensemble_method(MyAlgoEnsemble())
ensemble = builder.get_ensemble()
preds = ensemble()

print("The ensemble randomly picks the following models:")
for algo in ensemble.get_algo_ensemble():
    print(algo[AlgoEnsembleKeys.ID])