# BYOC for ensemble
## Define a new class that inherits from AlgoEnsemble

In [11]:
from monai.apps.auto3dseg import AlgoEnsemble, AlgoEnsembleBuilder
from monai.utils.enums import AlgoEnsembleKeys
import random
from copy import deepcopy

class MyAlgoEnsemble(AlgoEnsemble):
    """
    Ensemble strategy that uses a random subset of the loaded algos.

    Args:
        n_models: how many algos to draw from ``self.algos`` for the ensemble.
    """

    def __init__(self, n_models=3):
        super().__init__()
        self.n_models = n_models

    def collect_algos(self):
        """
        Fill ``self.algo_ensemble`` with deep copies of ``self.n_models`` algos
        picked at random, without repetition, from the ``self.algos`` list.

        Raises:
            ValueError: if more algos are requested than are loaded.
        """
        num_loaded = len(self.algos)
        if num_loaded < self.n_models:
            raise ValueError(f"Number of loaded Algo is {num_loaded}, but {self.n_models} algos are requested.")

        # Shuffle every position of self.algos, then keep the leading n_models.
        shuffled = list(range(num_loaded))
        random.shuffle(shuffled)

        # Deep copies keep the ensemble members independent of the entries in self.algos.
        self.algo_ensemble = []
        self.algo_ensemble.extend(deepcopy(self.algos[pos]) for pos in shuffled[: self.n_models])

## Create a simulated dataset and perform a 2-epoch training

### Import libraries

In [2]:
import os
import numpy as np
import nibabel as nib

from monai.data import create_test_image_3d
from monai.bundle.config_parser import ConfigParser
from monai.apps.auto3dseg import DataAnalyzer, BundleGen

### Create a simulated datalist

In [3]:
# Two unlabeled testing cases plus twelve labeled training cases,
# numbered 001-012 and assigned four per fold to folds 0, 1 and 2.
sim_datalist = {
    "testing": [{"image": f"val_{case:03d}.fake.nii.gz"} for case in (1, 2)],
    "training": [
        {
            "fold": (case - 1) // 4,
            "image": f"tr_image_{case:03d}.fake.nii.gz",
            "label": f"tr_label_{case:03d}.fake.nii.gz",
        }
        for case in range(1, 13)
    ],
}

# dataroot receives the simulated images, labels and datalist;
# work_dir receives the data statistics, the data source config and the generated algos.
dataroot = "./data"
work_dir = "./workdir"

da_output_yaml = os.path.join(work_dir, "datastats.yaml")
data_src_cfg = os.path.join(work_dir, "data_src_cfg.yaml")

# exist_ok=True creates the directory only when it is missing, so the cell can be
# re-run safely and there is no window between checking and creating.
for directory in (dataroot, work_dir):
    os.makedirs(directory, exist_ok=True)

# Generate a fake dataset: one synthetic 64x64x64 volume per datalist entry
for entry in sim_datalist["testing"] + sim_datalist["training"]:
    volume, mask = create_test_image_3d(64, 64, 64, rad_max=10, num_seg_classes=1)
    nib.save(nib.Nifti1Image(volume, affine=np.eye(4)), os.path.join(dataroot, entry["image"]))

    # Entries without a "label" key only get an image file.
    if "label" not in entry:
        continue
    nib.save(nib.Nifti1Image(mask, affine=np.eye(4)), os.path.join(dataroot, entry["label"]))

# Export the datalist next to the generated files as a json file
sim_datalist_filename = os.path.join(dataroot, "sim_datalist.json")
ConfigParser.export_config_file(sim_datalist, sim_datalist_filename)

In [4]:
### Run the Auto3DSeg data analyzer, algo generation, and training

In [5]:
# Step 1: analyze the simulated dataset; the analyzer is pointed at da_output_yaml for its output.
da = DataAnalyzer(sim_datalist_filename, dataroot, output_path=da_output_yaml)
da.get_all_case_stats()

# Step 2: describe the data source and export the description to data_src_cfg.
data_src = dict(
    name="fake_data",
    task="segmentation",
    modality="MRI",
    datalist=sim_datalist_filename,
    dataroot=dataroot,
    multigpu=False,
    class_names=["label_class"],
)

ConfigParser.export_config_file(data_src, data_src_cfg)

# Generate the algo bundles under work_dir from the data statistics and data source config.
bundle_generator = BundleGen(
    algo_path=work_dir,
    data_stats_filename=da_output_yaml,
    data_src_cfg_name=data_src_cfg,
)
bundle_generator.generate(work_dir, num_fold=2)
history = bundle_generator.get_history()


# Deliberately tiny training budget so the demo finishes quickly.
train_param = dict(
    CUDA_VISIBLE_DEVICES=[0],
    num_iterations=8,
    num_iterations_per_validation=4,
    num_images_per_batch=2,
    num_epochs=2,
    num_warmup_iterations=4,
)

# Each history record is a mapping whose values are the algo objects; train every one of them.
for record in history:
    for algo in record.values():
        algo.train(train_param)

File ./workdir/datastats.yaml already exists and will be overwritten.
100%|██████████| 12/12 [00:01<00:00,  9.18it/s]
algo_templates.tar.gz: 304kB [00:02, 153kB/s]                             

2022-09-12 14:24:45,148 - INFO - Downloaded: /tmp/tmpkwh8pbvy/algo_templates.tar.gz
2022-09-12 14:24:45,148 - INFO - Expected md5 is None, skip md5 check for file /tmp/tmpkwh8pbvy/algo_templates.tar.gz.
2022-09-12 14:24:45,149 - INFO - Writing into directory: ./workdir.





2022-09-12 14:24:45,429 - INFO - ./workdir/segresnet2d_0
2022-09-12 14:24:45,625 - INFO - ./workdir/segresnet2d_1
2022-09-12 14:24:45,974 - INFO - ./workdir/dints_0
2022-09-12 14:24:46,191 - INFO - ./workdir/dints_1
2022-09-12 14:24:46,391 - INFO - ./workdir/swinunetr_0
2022-09-12 14:24:46,599 - INFO - ./workdir/swinunetr_1
2022-09-12 14:24:46,794 - INFO - ./workdir/segresnet_0
2022-09-12 14:24:47,010 - INFO - ./workdir/segresnet_1
2022-09-12 14:24:47,012 - INFO - Launching: python ./workdir/segresnet2d_0/scripts/train.py run --config_file='./workdir/segresnet2d_0/configs/transforms_train.yaml','./workdir/segresnet2d_0/configs/network.yaml','./workdir/segresnet2d_0/configs/transforms_validate.yaml','./workdir/segresnet2d_0/configs/hyper_parameters.yaml','./workdir/segresnet2d_0/configs/transforms_infer.yaml' --num_iterations=8 --num_iterations_per_validation=4 --num_images_per_batch=2 --num_epochs=2 --num_warmup_iterations=4
2022-09-12 14:25:01,190 - INFO - CompletedProcess(args=['pyth

In [14]:
# Give the algo history to the builder and plug in the custom random-selection strategy.
builder = AlgoEnsembleBuilder(history, data_src_cfg)
random_ensemble = MyAlgoEnsemble()
builder.set_ensemble_method(random_ensemble)
ensemble = builder.get_ensemble()

# Calling the ensemble object produces the predictions.
preds = ensemble()

print("The ensemble randomly picks the following models:")
for member in ensemble.get_algo_ensemble():
    print(member[AlgoEnsembleKeys.ID])

0
[info] checkpoint ./workdir/dints_1/model_fold1/best_metric_model.pt loaded
2022-09-12 14:35:07,861 INFO image_writer.py:194 - writing: workdir/dints_1/prediction_testing/val_001.fake/val_001.fake_seg.nii.gz
[info] checkpoint ./workdir/dints_0/model_fold0/best_metric_model.pt loaded
2022-09-12 14:35:09,875 INFO image_writer.py:194 - writing: workdir/dints_0/prediction_testing/val_001.fake/val_001.fake_seg.nii.gz
[info] checkpoint ./workdir/segresnet2d_0/model_fold0/best_metric_model.pt loaded
2022-09-12 14:35:10,698 INFO image_writer.py:194 - writing: workdir/segresnet2d_0/prediction_testing/val_001.fake/val_001.fake_seg.nii.gz
1
[info] checkpoint ./workdir/dints_1/model_fold1/best_metric_model.pt loaded
2022-09-12 14:35:12,818 INFO image_writer.py:194 - writing: workdir/dints_1/prediction_testing/val_002.fake/val_002.fake_seg.nii.gz
[info] checkpoint ./workdir/dints_0/model_fold0/best_metric_model.pt loaded
2022-09-12 14:35:14,654 INFO image_writer.py:194 - writing: workdir/dints_0/