# Customize Algorithm Ensemble

In this notebook, we provide a brief example of how to customize your ensemble pipeline by defining a new ensemble class.

## 1 Set up environment, imports and datasets
### 1.1 Set up Environment

In [1]:
# Install MONAI (weekly build with the nibabel extra) only when `import monai` fails.
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"

### 1.2 Set up imports

In [2]:
import os
import numpy as np
import nibabel as nib
import random

from copy import deepcopy
from monai.apps.auto3dseg import (
    AlgoEnsemble,
    AlgoEnsembleBuilder,
    DataAnalyzer,
    BundleGen,
)

from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
from monai.utils.enums import AlgoEnsembleKeys

  from .autonotebook import tqdm as notebook_tqdm


### 1.3 Simulate a dataset and Auto3D datalist using MONAI functions

In [3]:
def _training_entry(case_id):
    """Build the datalist record of one training case.

    Cases are numbered from 1 and grouped four per fold, e.g. case 5 gives
    {"fold": 1, "image": "tr_image_005.fake.nii.gz", "label": "tr_label_005.fake.nii.gz"}.
    """
    return {
        "fold": (case_id - 1) // 4,
        "image": f"tr_image_{case_id:03d}.fake.nii.gz",
        "label": f"tr_label_{case_id:03d}.fake.nii.gz",
    }


# Two unlabeled testing images plus twelve labeled training pairs spread over folds 0-2.
sim_datalist = {
    "testing": [{"image": f"val_{case_id:03d}.fake.nii.gz"} for case_id in (1, 2)],
    "training": [_training_entry(case_id) for case_id in range(1, 13)],
}

# Folder that holds the simulated images/labels and folder used as the working directory.
dataroot = "./data"
work_dir = "./workdir"

# Paths of the data-statistics file and the data-source config inside the work directory.
da_output_yaml = os.path.join(work_dir, "datastats.yaml")
data_src_cfg = os.path.join(work_dir, "data_src_cfg.yaml")

# exist_ok=True keeps the cell re-runnable and avoids the check-then-create race
# of a separate os.path.isdir() test.
os.makedirs(dataroot, exist_ok=True)
os.makedirs(work_dir, exist_ok=True)

def _write_nifti(volume, filename):
    """Save ``volume`` under ``dataroot`` as a NIfTI file with an identity affine."""
    nib.save(nib.Nifti1Image(volume, affine=np.eye(4)), os.path.join(dataroot, filename))


# Create one synthetic 64x64x64 image/segmentation pair per datalist entry.
# Testing entries have no "label" key, so only their image is written.
for entry in sim_datalist["testing"] + sim_datalist["training"]:
    image, label = create_test_image_3d(64, 64, 64, rad_max=10, num_seg_classes=1)
    _write_nifti(image, entry["image"])
    if "label" in entry:
        _write_nifti(label, entry["label"])

# Export the datalist itself as a JSON file inside the data folder.
sim_datalist_filename = os.path.join(dataroot, "sim_datalist.json")
ConfigParser.export_config_file(sim_datalist, sim_datalist_filename)

## 2 Define a new class that inherits from AlgoEnsemble

In [4]:
class MyAlgoEnsemble(AlgoEnsemble):
    """
    Ensemble that uses a randomly chosen subset of the loaded algos.

    Args:
        n_models: number of algos to pick, without repetition, from ``self.algos``.
        seed: optional seed of a private random generator, which makes the
            selection reproducible. With the default ``None`` the module-level
            ``random`` state is used.

    Raises:
        ValueError: if ``n_models`` is negative.
    """

    def __init__(self, n_models=3, seed=None):
        # A negative count would pass the "more than loaded" check in collect_algos
        # and then slice from the end, silently picking the wrong number of algos.
        if n_models < 0:
            raise ValueError(f"n_models must be non-negative, but got {n_models}.")

        super().__init__()
        self.n_models = n_models
        self.seed = seed
        # A dedicated generator is only created when a seed is given.
        self._rng = random.Random(seed) if seed is not None else None

    def collect_algos(self):
        """
        Fill ``self.algo_ensemble`` with ``self.n_models`` algos drawn at random
        from the ``self.algos`` list.

        Raises:
            ValueError: if more algos are requested than are loaded.
        """
        n_loaded = len(self.algos)
        if self.n_models > n_loaded:
            raise ValueError(f"Number of loaded Algo is {n_loaded}, but {self.n_models} algos are requested.")

        indexes = list(range(n_loaded))
        shuffle = random.shuffle if self._rng is None else self._rng.shuffle
        shuffle(indexes)

        # Deep copies, so the ensemble does not hold the same objects as self.algos.
        self.algo_ensemble = [deepcopy(self.algos[idx]) for idx in indexes[: self.n_models]]

### 2.1 Run Auto3Dseg data analyzer, algo generation and training

In [5]:
# Analyze the simulated dataset; the analyzer is given `da_output_yaml` as its output path.
da = DataAnalyzer(sim_datalist_filename, dataroot, output_path=da_output_yaml)
da.get_all_case_stats()

# Describe the data source and export the description to `data_src_cfg`.
data_src = dict(
    modality="MRI",
    datalist=sim_datalist_filename,
    dataroot=dataroot,
    class_names=["label_class"],
)
ConfigParser.export_config_file(data_src, data_src_cfg)

# Generate the algo bundles in the work directory (num_fold=2) and keep the
# resulting history for training and ensembling.
bundle_generator = BundleGen(
    algo_path=work_dir,
    data_stats_filename=da_output_yaml,
    data_src_cfg_name=data_src_cfg,
)
bundle_generator.generate(work_dir, num_fold=2)
history = bundle_generator.get_history()


# Parameters handed to every algo's train() call; the iteration and epoch counts are small.
# NOTE(review): CUDA_VISIBLE_DEVICES=[0] presumably limits training to the first GPU - confirm.
train_param = dict(
    CUDA_VISIBLE_DEVICES=[0],
    num_iterations=8,
    num_iterations_per_validation=4,
    num_images_per_batch=2,
    num_epochs=2,
    num_warmup_iterations=4,
)

# Train every algo stored in the history records, one after another.
for record in history:
    for algo in record.values():
        algo.train(train_param)

100%|██████████| 12/12 [00:01<00:00,  9.60it/s]
algo_templates.tar.gz: 100%|██████████| 280k/280k [00:01<00:00, 172kB/s]  

2022-09-15 18:01:04,601 - INFO - Downloaded: /tmp/tmp55bhww8u/algo_templates.tar.gz
2022-09-15 18:01:04,602 - INFO - Expected md5 is None, skip md5 check for file /tmp/tmp55bhww8u/algo_templates.tar.gz.
2022-09-15 18:01:04,603 - INFO - Writing into directory: ./workdir.





2022-09-15 18:01:04,892 - INFO - ./workdir/segresnet2d_0
2022-09-15 18:01:05,222 - INFO - ./workdir/segresnet2d_1
2022-09-15 18:01:05,450 - INFO - ./workdir/dints_0
2022-09-15 18:01:05,671 - INFO - ./workdir/dints_1
2022-09-15 18:01:05,897 - INFO - ./workdir/swinunetr_0
2022-09-15 18:01:06,145 - INFO - ./workdir/swinunetr_1
2022-09-15 18:01:06,376 - INFO - ./workdir/segresnet_0
2022-09-15 18:01:06,609 - INFO - ./workdir/segresnet_1
2022-09-15 18:01:06,610 - INFO - Launching: python ./workdir/segresnet2d_0/scripts/train.py run --config_file='./workdir/segresnet2d_0/configs/transforms_train.yaml','./workdir/segresnet2d_0/configs/network.yaml','./workdir/segresnet2d_0/configs/transforms_validate.yaml','./workdir/segresnet2d_0/configs/hyper_parameters.yaml','./workdir/segresnet2d_0/configs/transforms_infer.yaml' --num_iterations=8 --num_iterations_per_validation=4 --num_images_per_batch=2 --num_epochs=2 --num_warmup_iterations=4
2022-09-15 18:01:20,696 - INFO - CompletedProcess(args=['pyth

### 2.2 Apply MyAlgoEnsemble 

In [6]:
# Register the custom ensemble (default n_models=3) with the builder and build it.
builder = AlgoEnsembleBuilder(history, data_src_cfg)
builder.set_ensemble_method(MyAlgoEnsemble())
ensemble = builder.get_ensemble()

# Calling the ensemble object produces the predictions.
preds = ensemble()

# Report which algos ended up in the random selection.
print("The ensemble randomly picks the following models:")
for picked in ensemble.get_algo_ensemble():
    print(picked[AlgoEnsembleKeys.ID])

0
[info] checkpoint ./workdir/segresnet_0/model_fold0/best_metric_model.pt loaded
2022-09-15 18:05:59,870 INFO image_writer.py:194 - writing: workdir/segresnet_0/prediction_testing/val_001.fake/val_001.fake_seg.nii.gz
[info] checkpoint ./workdir/swinunetr_0/model_fold0/best_metric_model.pt loaded
2022-09-15 18:06:00,831 INFO image_writer.py:194 - writing: workdir/swinunetr_0/prediction_testing/val_001.fake/val_001.fake_seg.nii.gz
[info] checkpoint ./workdir/dints_0/model_fold0/best_metric_model.pt loaded
2022-09-15 18:06:03,386 INFO image_writer.py:194 - writing: workdir/dints_0/prediction_testing/val_001.fake/val_001.fake_seg.nii.gz
1
[info] checkpoint ./workdir/segresnet_0/model_fold0/best_metric_model.pt loaded
2022-09-15 18:06:04,433 INFO image_writer.py:194 - writing: workdir/segresnet_0/prediction_testing/val_002.fake/val_002.fake_seg.nii.gz
[info] checkpoint ./workdir/swinunetr_0/model_fold0/best_metric_model.pt loaded
2022-09-15 18:06:05,261 INFO image_writer.py:194 - writing: 