# Customize Algorithm Ensemble in Auto3DSeg

In this notebook, we provide a brief example of how to customize your ensemble pipeline by defining a new ensemble class.

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"

## Setup imports

In [None]:
import os
import numpy as np
import nibabel as nib
import random

from copy import deepcopy
from pathlib import Path

from monai.apps.auto3dseg import (
    AlgoEnsemble,
    AlgoEnsembleBuilder,
    DataAnalyzer,
    BundleGen,
)

from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
from monai.utils.enums import AlgoEnsembleKeys

## Simulate a dataset and an Auto3DSeg datalist using MONAI functions

In [None]:
# Datalist with two unlabeled testing cases and twelve labeled training cases;
# the training cases are numbered 1-12 and assigned to folds 0, 1 and 2 in
# consecutive groups of four.
sim_datalist = {
    "testing": [{"image": f"val_{case:03d}.fake.nii.gz"} for case in range(1, 3)],
    "training": [
        {
            "fold": (case - 1) // 4,
            "image": f"tr_image_{case:03d}.fake.nii.gz",
            "label": f"tr_label_{case:03d}.fake.nii.gz",
        }
        for case in range(1, 13)
    ],
}

# Folder for the simulated images and folder for everything Auto3DSeg produces,
# both relative to the current working directory.
dataroot = str(Path("./data"))
work_dir = str(Path("./ensemble_byoc_work_dir"))

# Paths of the two YAML files that live inside the working directory.
da_output_yaml = os.path.join(work_dir, "datastats.yaml")
data_src_cfg = os.path.join(work_dir, "data_src_cfg.yaml")

# Create both folders up front; an already existing folder is left as it is.
os.makedirs(dataroot, exist_ok=True)
os.makedirs(work_dir, exist_ok=True)

# Generate a fake dataset.
# One seeded generator is created once and shared by every case: the dataset is
# still reproducible from run to run, but each case is drawn from a different
# point of the random stream. Re-creating RandomState(42) inside the loop would
# make every simulated volume (and label) identical.
rng = np.random.RandomState(42)
for d in sim_datalist["testing"] + sim_datalist["training"]:
    im, seg = create_test_image_3d(64, 64, 64, rad_max=10, num_seg_classes=1, random_state=rng)
    nib_image = nib.Nifti1Image(im, affine=np.eye(4))
    image_fpath = os.path.join(dataroot, d["image"])
    nib.save(nib_image, image_fpath)

    # Testing entries have no "label" key, so only training cases get a label file.
    if "label" in d:
        nib_label = nib.Nifti1Image(seg, affine=np.eye(4))
        label_fpath = os.path.join(dataroot, d["label"])
        nib.save(nib_label, label_fpath)

# Write the datalist to a JSON file next to the simulated images.
sim_datalist_filename = os.path.join(dataroot, "sim_datalist.json")
ConfigParser.export_config_file(sim_datalist, sim_datalist_filename)

## Define a new class that inherits from `AlgoEnsemble`

In [None]:
class MyAlgoEnsemble(AlgoEnsemble):
    """
    Randomly select N models to do ensemble.

    Args:
        n_models: number of algos to pick (without repetition) from the loaded
            algos. Must be at least 1.

    Raises:
        ValueError: if ``n_models`` is smaller than 1.
    """

    def __init__(self, n_models=3):
        # Reject 0 and negative values early: 0 would produce an empty ensemble and
        # a negative value would silently select the wrong number of models.
        if n_models < 1:
            raise ValueError(f"n_models must be at least 1, but got {n_models}.")

        super().__init__()
        self.n_models = n_models

    def collect_algos(self):
        """
        collect_algos defines the method to collect the target algos from the self.algos list.

        Picks ``self.n_models`` distinct entries of ``self.algos`` at random and
        stores deep copies of them in ``self.algo_ensemble``.

        Raises:
            ValueError: if more algos are requested than are loaded.
        """
        n = len(self.algos)
        if self.n_models > n:
            raise ValueError(f"Number of loaded Algo is {n}, but {self.n_models} algos are requested.")

        # random.sample returns unique indexes in random order, i.e. the same
        # selection as shuffling all indexes and keeping the first n_models.
        indexes = random.sample(range(n), self.n_models)
        # Deep copies keep the ensemble entries independent of self.algos.
        self.algo_ensemble = [deepcopy(self.algos[idx]) for idx in indexes]

## Run Auto3DSeg data analyzer, algo generation and training

In [None]:
# Step 1: analyze the simulated dataset; the YAML path is passed as the output location.
da = DataAnalyzer(sim_datalist_filename, dataroot, output_path=da_output_yaml)
da.get_all_case_stats()

# Step 2: describe the data source and export it to the config file handed to BundleGen below.
data_src = dict(modality="MRI", datalist=sim_datalist_filename, dataroot=dataroot)
ConfigParser.export_config_file(data_src, data_src_cfg)

# Step 3: generate the algorithm bundles in the working directory with num_fold=2
# and keep the resulting history for the training and ensemble steps.
bundle_generator = BundleGen(
    algo_path=work_dir,
    data_stats_filename=da_output_yaml,
    data_src_cfg_name=data_src_cfg,
)
bundle_generator.generate(work_dir, num_fold=2)
history = bundle_generator.get_history()


# Requested number of epochs for this short demo run.
max_epochs = 2

# Safeguard: never train with fewer than 2 epochs.
max_epochs = max(2, max_epochs)

# Training overrides; every iteration count is derived from max_epochs.
train_param = dict(
    CUDA_VISIBLE_DEVICES=[0],  # use only 1 gpu
    num_iterations=4 * max_epochs,
    num_iterations_per_validation=2 * max_epochs,
    num_images_per_batch=2,
    num_epochs=max_epochs,
    num_warmup_iterations=2 * max_epochs,
)

# Train every algo stored in the history with the same parameters.
for entry in history:
    for algo in entry.values():
        algo.train(train_param)

## Apply MyAlgoEnsemble 

In [None]:
# Build the ensemble from the training history, using the custom random-selection method.
builder = AlgoEnsembleBuilder(history, data_src_cfg)
builder.set_ensemble_method(MyAlgoEnsemble())
ensemble = builder.get_ensemble()

# Calling the ensemble object runs it; the result is kept in `preds`.
preds = ensemble()

# Report which models ended up in the ensemble.
print("The ensemble randomly picks the following models:")
for selected in ensemble.get_algo_ensemble():
    print(selected[AlgoEnsembleKeys.ID])