Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# nnUNet MONAI Bundle

This notebook demonstrates how to create a MONAI Bundle for a trained nnUNet model and use it for inference. This is needed when another application from the MONAI ecosystem requires a MONAI Bundle (e.g., MONAI Label, MonaiAlgo for Federated Learning).

This notebook covers the steps to convert a trained nnUNet model into a consumable MONAI Bundle. The nnUNet training is performed here with the `nnUNetV2Runner`.

Optionally, the notebook also demonstrates how to use the same nnUNet MONAI Bundle to train a new model. This might be needed in applications where the nnUNet training must be performed through a MONAI Bundle (e.g., Active Learning in MONAI Label, MonaiAlgo for Federated Learning).

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm, ignite, tensorboard, fire, pyyaml]"
!python -c "import matplotlib" || pip install -q matplotlib
!python -c "import nnunetv2" || pip install -q nnunetv2

## Setup imports

In [None]:
from monai.config import print_config
import os
import tempfile
from monai.bundle.config_parser import ConfigParser
from monai.apps.nnunet import nnUNetV2Runner
from monai.bundle.nnunet import convert_nnunet_to_monai_bundle
import nnunetv2
print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [None]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
    os.makedirs(directory, exist_ok=True)
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

## Download Decathlon Spleen Dataset and Generate Data List

To get the Decathlon Spleen dataset and generate the corresponding data list, follow the instructions in the [MSD Datalist Generator Notebook](../auto3dseg/notebooks/msd_datalist_generator.ipynb).

At the end of the notebook, remember to copy the generated `msd_task09_spleen_folds.json` file to the `<root_dir>/Task09_Spleen` directory.
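
Before going further, it can help to check that the copied datalist has the structure this notebook relies on: `nnUNetV2Runner` reads the training folds from it, and the inference bundle reads its `testing` list. The helper below is a hypothetical sketch (`check_datalist` is not part of MONAI); the expected keys (`training`, `testing`, and a per-case `fold`) are assumptions based on the MSD datalist generator output.

```python
import json
from pathlib import Path


def check_datalist(datalist_path, data_root):
    """Return a list of problems found in a Decathlon-style datalist (empty if none).

    Assumed layout: a "training" list of dicts with "image", "label" and "fold",
    and a "testing" list of dicts (or plain strings) whose image paths are
    relative to `data_root`.
    """
    datalist = json.loads(Path(datalist_path).read_text())
    problems = []
    for key in ("training", "testing"):
        if key not in datalist:
            problems.append(f"missing '{key}' list")
    for i, item in enumerate(datalist.get("training", [])):
        if "fold" not in item:
            problems.append(f"training[{i}] has no 'fold' key")
    # Check that every image path resolves under data_root.
    for item in datalist.get("training", []) + datalist.get("testing", []):
        image = item["image"] if isinstance(item, dict) else item
        if not (Path(data_root) / image).exists():
            problems.append(f"file not found: {image}")
    return problems
```

Called with the `Task09_Spleen` paths under `root_dir` (the JSON file and its parent folder), it should return an empty list once the datalist is in place.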

## nnUNet Experiment with nnUNetV2Runner

In the following sections, we will use the nnUNetV2Runner to train a model on the spleen dataset from the Medical Segmentation Decathlon.

We first create the configuration file for the `nnUNetV2Runner`:

In [None]:
nnunet_root_dir = os.path.join(root_dir, "nnUNet")

os.makedirs(nnunet_root_dir, exist_ok=True)

data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml")
data_src = {
    "modality": "CT",
    "datalist": os.path.join(root_dir, "Task09_Spleen/msd_task09_spleen_folds.json"),
    "dataroot": os.path.join(root_dir, "Task09_Spleen"),
}

ConfigParser.export_config_file(data_src, data_src_cfg)


In [None]:
runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name="nnUNetTrainer_10epochs", work_dir=nnunet_root_dir)

In [None]:
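# Convert the MSD datalist into the nnUNet raw dataset format before planning.
runner.convert_dataset()
runner.plan_and_process(npfp=2, n_proc=[2, 2, 2])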

In [None]:
runner.train(configs="3d_fullres")

In [None]:
runner.run(run_train=True, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False)

## nnUNet MONAI Bundle for Inference

This section shows how to use the trained model for inference on new data through a MONAI Bundle that wraps the native nnUNet model together with its pre- and post-processing steps.

We first create the MONAI Bundle for the nnUNet model:

In [None]:
%%bash

python -m monai.bundle init_bundle nnUNetBundle
# Drop the default JSON inference config; an inference.yaml is written in the next cell.
rm -f nnUNetBundle/configs/inference.json

mkdir -p nnUNetBundle/src
touch nnUNetBundle/src/__init__.py
which tree && tree nnUNetBundle || true

We then populate the MONAI Bundle with the configuration for inference:

In [None]:
%%writefile nnUNetBundle/configs/inference.yaml

imports: 
  - $import json
  - $from pathlib import Path
  - $import os
  - $import monai.bundle.nnunet
  - $from ignite.contrib.handlers.tqdm_logger import ProgressBar
  - $import shutil


output_dir: "."
bundle_root: "."
data_list_file: "."
data_dir: "."

prediction_suffix: "prediction"

test_data_list: "$monai.data.load_decathlon_datalist(@data_list_file, is_segmentation=True, data_list_key='testing', base_dir=@data_dir)"
image_modality_keys: "$list(@modality_conf.keys())"
image_key: "image"
image_suffix: "@image_key"

preprocessing:
  _target_: Compose
  transforms:
  - _target_: LoadImaged
    keys: "image"
    ensure_channel_first: True
    image_only: False

test_dataset:
  _target_: Dataset
  data: "$@test_data_list"
  transform: "@preprocessing"

test_loader:
  _target_: DataLoader
  dataset: "@test_dataset"
  batch_size: 1


device: "$torch.device('cuda')"

nnunet_config:
  model_folder: "$@bundle_root + '/models'"

network_def: "$monai.bundle.nnunet.get_nnunet_monai_predictor(**@nnunet_config)"

postprocessing:
  _target_: "Compose"
  transforms:
    - _target_: Transposed
      keys: "pred"
      indices:
      - 0
      - 3
      - 2
      - 1
    - _target_: SaveImaged
      keys: "pred"
      resample: False
      output_postfix: "@prediction_suffix"
      output_dir: "@output_dir"
      meta_keys: "image_meta_dict"


testing:
  dataloader: "$@test_loader"
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  test_inferer: "$@inferer"

inferer: 
  _target_: "SimpleInferer"

validator:
  _target_: "SupervisedEvaluator"
  postprocessing: "$@postprocessing"
  device: "$@device"
  inferer: "$@testing#test_inferer"
  val_data_loader: "$@testing#dataloader"
  network: "@network_def"
  #prepare_batch: "$src.inferer.prepare_nnunet_inference_batch"
  val_handlers:
  - _target_: "CheckpointLoader"
    load_path: "$@bundle_root+'/models/model.pt'"
    load_dict:
      network_weights: '$@network_def.network_weights'
run:
  - "$@testing#pbar.attach(@validator)"
  - "$@validator.run()"
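
The `Transposed` step in the postprocessing keeps the channel axis first and reverses the order of the three spatial axes before `SaveImaged` writes the prediction. Our assumption is that this undoes an axis-order difference between nnUNet (which, with SimpleITK-based I/O, handles arrays in (z, y, x) order) and MONAI's `LoadImaged`/`SaveImaged`, which use (x, y, z) for NIfTI files. The stdlib-only sketch below (the helper name `transpose_shape` is hypothetical) shows the effect of `indices: [0, 3, 2, 1]` on a tensor shape:

```python
def transpose_shape(shape, indices=(0, 3, 2, 1)):
    """Return the shape obtained by permuting axes as Transposed(indices=...) would.

    Axis 0 (channel) stays in place and the three spatial axes are reversed,
    e.g. (C, D0, D1, D2) -> (C, D2, D1, D0).
    """
    if sorted(indices) != list(range(len(shape))):
        raise ValueError("indices must be a permutation of the axes")
    return tuple(shape[i] for i in indices)


# A channel-first prediction with spatial size 40 x 50 x 60.
print(transpose_shape((1, 40, 50, 60)))  # (1, 60, 50, 40)
```

Because this permutation is its own inverse, applying the same `indices` twice restores the original axis order.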

### nnUNet to MONAI Bundle Conversion

Finally, we convert the trained nnUNet model into a Bundle-compatible format using the `convert_nnunet_to_monai_bundle` function:

In [None]:
nnunet_config = {
    "dataset_name_or_id": "001",
    # Must match the trainer used by nnUNetV2Runner above.
    "nnunet_trainer": "nnUNetTrainer_10epochs",
}

bundle_root = "nnUNetBundle"

# Convert the checkpoint of fold 0.
convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, 0)

You can then inspect the content of the `models` folder to verify that the model has been converted to the MONAI Bundle format.

In [None]:
%%bash

which tree && tree nnUNetBundle/models || true
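
If `tree` is not available, a few lines of Python give a similar listing. The sketch below is hypothetical (`list_models_folder` is not a MONAI function) and also reports which of the files referenced by this notebook's configs are missing: `model.pt` (loaded by `inference.yaml`), and `nnunet_checkpoint.pth`, `plans.json` and `dataset.json` (written by the training config later on). Treating all four as required for inference is an assumption; the exact files produced by `convert_nnunet_to_monai_bundle` may differ, so nested folders are searched too.

```python
from pathlib import Path

# Assumption: files referenced by this notebook's inference and training configs.
EXPECTED_FILES = ("model.pt", "nnunet_checkpoint.pth", "plans.json", "dataset.json")


def list_models_folder(models_dir):
    """Print a tree-like listing of models_dir and return expected files that are missing."""
    models_dir = Path(models_dir)
    for path in sorted(models_dir.rglob("*")):
        # Indent each entry by its depth below models_dir.
        depth = len(path.relative_to(models_dir).parts) - 1
        print("    " * depth + path.name + ("/" if path.is_dir() else ""))
    present = {p.name for p in models_dir.rglob("*") if p.is_file()}
    return [name for name in EXPECTED_FILES if name not in present]
```

For example, `list_models_folder("nnUNetBundle/models")` prints the folder content and returns an empty list when all four files are found.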

### Test the MONAI Bundle for Inference

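The MONAI Bundle is now ready to be used for inference on new data. The command below reads the data from `$MONAI_DATA_DIRECTORY`, so it assumes that this environment variable was set when running the notebook (otherwise, replace it with the `root_dir` printed earlier).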

In [None]:
%%bash

python -m monai.bundle run \
    --config-file nnUNetBundle/configs/inference.yaml \
    --bundle-root nnUNetBundle \
    --data_list_file "$MONAI_DATA_DIRECTORY/Task09_Spleen/msd_task09_spleen_folds.json" \
    --output-dir nnUNetBundle/pred_output \
    --data_dir "$MONAI_DATA_DIRECTORY/Task09_Spleen" \
    --logging-file nnUNetBundle/configs/logging.conf
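
Once the run finishes, the predictions can be checked against the `testing` list. This is a hypothetical sketch (neither helper is part of MONAI) that assumes `SaveImaged`'s default naming, `<subject>_<postfix><ext>` (here `<subject>_prediction.nii.gz`), possibly inside one sub-folder per subject under `nnUNetBundle/pred_output`.

```python
from pathlib import Path


def find_predictions(output_dir, postfix="prediction", ext=".nii.gz"):
    """Return the sorted prediction files found anywhere under output_dir."""
    return sorted(Path(output_dir).rglob(f"*_{postfix}{ext}"))


def missing_predictions(test_images, output_dir, postfix="prediction", ext=".nii.gz"):
    """Return the test images for which no prediction file was found."""
    found = {p.name for p in find_predictions(output_dir, postfix, ext)}
    missing = []
    for image in test_images:
        name = Path(image).name
        # Strip the (possibly double) extension to get the subject name.
        subject = name[: -len(ext)] if name.endswith(ext) else Path(name).stem
        if f"{subject}_{postfix}{ext}" not in found:
            missing.append(image)
    return missing
```

The `test_images` argument would be the `image` entries of the datalist's `testing` list.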

## Optional: Training nnUNet from the MONAI Bundle

In some cases, you may want to train the nnUNet model from the MONAI Bundle (i.e., without using the nnUNetV2Runner).
This is usually the case when the training logic is designed to work with a MONAI Bundle, such as Active Learning in MONAI Label or Federated Learning in NVFlare through the MonaiAlgo implementation.

This can be done by following the steps below:

In [None]:
%%writefile nnUNetBundle/configs/train.yaml

imports:
  - $import json
  - $import os
  - $import nnunetv2
  - $import src
  - $import src.nnunet_batch_preparation
  - $import monai.bundle.nnunet
  - $import shutil
  - $import pathlib


bundle_root: .
ckpt_dir: "$@bundle_root + '/models'"
num_classes: 2

nnunet_configuration: "3d_fullres"
dataset_name_or_id: "001"
fold: "0"
trainer_class_name: "nnUNetTrainer"
plans_identifier: "nnUNetPlans"

dataset_name: "$nnunetv2.utilities.dataset_name_id_conversion.maybe_convert_to_dataset_name(@dataset_name_or_id)"
nnunet_model_folder: "$os.path.join(os.environ['nnUNet_results'], @dataset_name, @trainer_class_name+'__'+@plans_identifier+'__'+@nnunet_configuration)"

nnunet_config:
  dataset_name_or_id: "@dataset_name_or_id"
  configuration: "@nnunet_configuration"
  trainer_class_name: "@trainer_class_name"
  plans_identifier: "@plans_identifier"
  fold: "@fold"


nnunet_trainer: "$monai.bundle.nnunet.get_nnunet_trainer(**@nnunet_config)"

iterations: $@nnunet_trainer.num_iterations_per_epoch
device: $@nnunet_trainer.device
epochs: $@nnunet_trainer.num_epochs

loss: $@nnunet_trainer.loss
lr_scheduler: $@nnunet_trainer.lr_scheduler

network_def: $@nnunet_trainer.network
network: $@nnunet_trainer.network

optimizer: $@nnunet_trainer.optimizer


checkpoint:
  init_args: '$@nnunet_trainer.my_init_kwargs'
  trainer_name: '$@nnunet_trainer.__class__.__name__'
  inference_allowed_mirroring_axes: '$@nnunet_trainer.inference_allowed_mirroring_axes'

checkpoint_filename: "$@bundle_root+'/models/nnunet_checkpoint.pth'"
output_dir: $@bundle_root + '/logs'

train:
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  dataloader: $@nnunet_trainer.dataloader_train
  train_data: "$[{'case_identifier':k} for k in @nnunet_trainer.dataloader_train.generator._data.dataset.keys()]"
  train_dataset:
    _target_: Dataset
    data: "@train#train_data"
  handlers:
  - _target_: LrScheduleHandler
    lr_scheduler: '@lr_scheduler'
    print_lr: true
  - _target_: ValidationHandler
    epoch_level: true
    interval: '@val_interval'
    validator: '@validate#evaluator'
  #- _target_: StatsHandler
  #  output_transform: $monai.handlers.from_engine(['loss'], first=True)
  #  tag_name: train_loss
  - _target_: TensorBoardStatsHandler
    log_dir: '@output_dir'
    output_transform: $monai.handlers.from_engine(['loss'], first=True)
    tag_name: train_loss
  inferer:
    _target_: SimpleInferer
  key_metric:
    Train_Dice:
      _target_: "MeanDice"
      include_background: False
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean"
  additional_metrics:
    Train_Dice_per_class:
      _target_: "MeanDice"
      include_background: False
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean_batch"
  postprocessing:
    _target_: "Compose"
    transforms:
    - _target_: Lambdad
      keys:
        - "pred"
        - "label"
      func: "$lambda x: x[0]"
    - _target_: Activationsd
      keys:
        - "pred"
      softmax: True
    - _target_: AsDiscreted
      keys:
       - "pred"
      threshold: 0.5
    - _target_: AsDiscreted
      keys:
        - "label"
      to_onehot: "@num_classes"
  postprocessing_region_based:
    _target_: "Compose"
    transforms:
    - _target_: Lambdad
      keys:
        - "pred"
        - "label"
      func: "$lambda x: x[0]"
    - _target_: Activationsd
      keys:
        - "pred"
      sigmoid: True
    - _target_: AsDiscreted
      keys:
       - "pred"
      threshold: 0.5
  trainer:
    _target_: SupervisedTrainer
    amp: true
    device: '@device'
    additional_metrics: "@train#additional_metrics"
    epoch_length: "@iterations"
    inferer: '@train#inferer'
    key_train_metric: '@train#key_metric'
    loss_function: '@loss'
    max_epochs: '@epochs'
    network: '@network'
    prepare_batch: "$src.nnunet_batch_preparation.prepare_nnunet_batch"
    optimizer: '@optimizer'
    postprocessing: '@train#postprocessing'
    train_data_loader: '@train#dataloader'
    train_handlers: '@train#handlers'

val_interval: 1
validate:
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  key_metric:
    Val_Dice:
      _target_: "MeanDice"
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean"
      include_background: False
  additional_metrics:
    Val_Dice_per_class:
      _target_: "MeanDice"
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean_batch"
      include_background: False
  dataloader: $@nnunet_trainer.dataloader_val
  evaluator:
    _target_: SupervisedEvaluator
    additional_metrics: '@validate#additional_metrics'
    amp: true
    epoch_length: $@nnunet_trainer.num_val_iterations_per_epoch
    device: '@device'
    inferer: '@validate#inferer'
    key_val_metric: '@validate#key_metric'
    network: '@network'
    postprocessing: '@validate#postprocessing'
    val_data_loader: '@validate#dataloader'
    val_handlers: '@validate#handlers'
    prepare_batch: "$src.nnunet_batch_preparation.prepare_nnunet_batch"
  handlers:
  - _target_: StatsHandler
    iteration_log: false
  - _target_: TensorBoardStatsHandler
    iteration_log: false
    log_dir: '@output_dir'
  - _target_: "CheckpointSaver"
    save_dir: "$str(@bundle_root)+'/models'"
    save_interval: 1
    n_saved: 1
    save_key_metric: true
    save_dict:
      network_weights: '$@nnunet_trainer.network._orig_mod'
      optimizer_state: '$@nnunet_trainer.optimizer'
      scheduler: '$@nnunet_trainer.lr_scheduler'
  inferer:
    _target_: SimpleInferer
  postprocessing: '%train#postprocessing'

run:
- "$torch.save(@checkpoint,@checkpoint_filename)"
- "$shutil.copy(pathlib.Path(@nnunet_model_folder).joinpath('dataset.json'), @bundle_root+'/models/dataset.json')"
- "$shutil.copy(pathlib.Path(@nnunet_model_folder).joinpath('plans.json'), @bundle_root+'/models/plans.json')"
- "$@train#pbar.attach(@train#trainer,output_transform=lambda x: {'loss': x[0]['loss']})"
- "$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])"
- $@train#trainer.run()

initialize:
- $monai.utils.set_determinism(seed=123)
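
The `nnunet_model_folder` entry in the configuration above is where the run looks for the `dataset.json` and `plans.json` that it copies into the bundle. Its layout can be reproduced in plain Python to check the path before launching the training. This is a hypothetical helper (not a MONAI or nnUNet function), and it simplifies the dataset lookup: nnunetv2's `maybe_convert_to_dataset_name` searches the nnUNet folders more thoroughly than the single glob used here.

```python
from pathlib import Path


def find_nnunet_model_folder(results_dir, dataset_id, trainer="nnUNetTrainer",
                             plans="nnUNetPlans", configuration="3d_fullres"):
    """Build <results_dir>/Dataset<id>_<name>/<trainer>__<plans>__<configuration>.

    Simplifying assumption: a numeric id maps to exactly one folder named
    "Dataset<id:03d>_*" directly under results_dir.
    """
    matches = sorted(Path(results_dir).glob(f"Dataset{int(dataset_id):03d}_*"))
    if len(matches) != 1:
        raise FileNotFoundError(
            f"expected one dataset folder for id {dataset_id}, found {len(matches)}"
        )
    return matches[0] / f"{trainer}__{plans}__{configuration}"
```

Because the folder name includes the trainer class, `trainer_class_name` in `train.yaml` determines which results folder is used; `results_dir` corresponds to the `nnUNet_results` environment variable.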

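Additionally, we write the Python function that turns a batch from the nnUNet DataLoader into the `(input, target)` pair expected by the MONAI engines (it is referenced as `prepare_batch` in `train.yaml`):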

In [None]:
%%writefile nnUNetBundle/src/nnunet_batch_preparation.py

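def prepare_nnunet_batch(batch, device, non_blocking, **kwargs):
    """Move an nnUNet batch to `device` and return it as (input, target) for MONAI engines."""
    data = batch["data"].to(device, non_blocking=non_blocking)
    # With deep supervision, nnUNet provides one target per output resolution.
    if isinstance(batch["target"], list):
        target = [i.to(device, non_blocking=non_blocking) for i in batch["target"]]
    else:
        target = batch["target"].to(device, non_blocking=non_blocking)
    return data, target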

Finally, since the original nnUNet `PolyLRScheduler` is not compatible with MONAI Bundle training, we create a custom `PolyLRScheduler` class that overrides the original implementation used during nnUNet training.

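The incompatibility comes from the learning-rate logging: MONAI Bundle training reads the current learning rate through the scheduler's `get_last_lr` method, but the original `step()` only updates the optimizer's parameter groups and never records the value that `get_last_lr` returns. The custom implementation stores it in `_last_lr` at every step and exposes it through `get_last_lr`.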
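
The rule applied by `step()` is the polynomial decay $lr_t = lr_0 \, (1 - t/T)^{p}$, where $lr_0$ is `initial_lr`, $T$ is `max_steps` and $p$ is `exponent`. The sketch below evaluates it without torch; the values `initial_lr=1e-2` and `max_steps=1000` are assumptions meant to resemble nnUNet defaults with one scheduler step per epoch, and the range check is an addition that the scheduler class itself does not perform.

```python
def poly_lr(initial_lr, current_step, max_steps, exponent=0.9):
    """Learning rate given by the polynomial decay used in PolyLRScheduler.step()."""
    # Guard added for this sketch: (1 - t/T) must stay non-negative.
    if not 0 <= current_step <= max_steps:
        raise ValueError("current_step must lie in [0, max_steps]")
    return initial_lr * (1 - current_step / max_steps) ** exponent


# Assumed nnUNet-like setting: initial_lr=1e-2 over 1000 epochs.
for step in (0, 250, 500, 999):
    print(step, round(poly_lr(1e-2, step, 1000), 6))
```

The learning rate starts at `initial_lr` and decays smoothly towards zero at `max_steps`.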

In [None]:
import nnunetv2
print(nnunetv2.__file__)

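Overwrite the original `PolyLRScheduler` class with the custom implementation, replacing `</path/to/nnunetv2>` below with the directory printed above (the folder that contains nnunetv2's `__init__.py`):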

In [None]:
%%writefile </path/to/nnunetv2>/training/lr_scheduler/polylr.py

from torch.optim.lr_scheduler import _LRScheduler


class PolyLRScheduler(_LRScheduler):
    def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.max_steps = max_steps
        self.exponent = exponent
        self.ctr = 0
        # last_epoch = current_step (or -1 for a fresh start); the deprecated `verbose` argument is omitted.
        super().__init__(optimizer, current_step if current_step is not None else -1)

    def step(self, current_step=None):
        if current_step is None or current_step == -1:
            current_step = self.ctr
            self.ctr += 1

        new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def get_last_lr(self):
        return self._last_lr

We can now train the nnUNet model using the MONAI Bundle:

In [None]:
%%bash

export nnUNet_raw="$MONAI_DATA_DIRECTORY/nnUNet/nnUNet_raw_data_base"
export nnUNet_preprocessed="$MONAI_DATA_DIRECTORY/nnUNet/nnUNet_preprocessed"
export nnUNet_results="$MONAI_DATA_DIRECTORY/nnUNet/nnUNet_trained_models"

export BUNDLE=nnUNetBundle
export PYTHONPATH=$BUNDLE

#export nnUNet_def_n_proc=2
#export nnUNet_n_proc_DA=2

python -m monai.bundle run \
    --bundle-root "$BUNDLE" \
    --config-file "$BUNDLE/configs/train.yaml"
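
When `MONAI_DATA_DIRECTORY` is not set, `root_dir` is a temporary directory and the `$MONAI_DATA_DIRECTORY` paths in the bash cell above are empty. One option, sketched below, is to build the same environment in Python from `root_dir` and launch the run with `subprocess`. `nnunet_env` and `train_bundle_cmd` are hypothetical helpers, and the sub-folder names are assumed to match those used by `nnUNetV2Runner` under its `work_dir`, as in the bash cell.

```python
import os
import sys


def nnunet_env(data_dir, bundle="nnUNetBundle"):
    """Copy of the current environment with the variables set by the bash cell."""
    env = dict(os.environ)
    nnunet_dir = os.path.join(data_dir, "nnUNet")
    env["nnUNet_raw"] = os.path.join(nnunet_dir, "nnUNet_raw_data_base")
    env["nnUNet_preprocessed"] = os.path.join(nnunet_dir, "nnUNet_preprocessed")
    env["nnUNet_results"] = os.path.join(nnunet_dir, "nnUNet_trained_models")
    env["PYTHONPATH"] = bundle  # lets train.yaml import the bundle's `src` package
    return env


def train_bundle_cmd(bundle="nnUNetBundle"):
    """Same `monai.bundle run` call as the bash cell, using the notebook's interpreter."""
    return [
        sys.executable, "-m", "monai.bundle", "run",
        "--bundle-root", bundle,
        "--config-file", os.path.join(bundle, "configs", "train.yaml"),
    ]
```

The run can then be started with `subprocess.run(train_bundle_cmd(), env=nnunet_env(root_dir), check=True)`. Using `sys.executable` ensures that the subprocess runs in the same Python environment as the notebook kernel.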

You can follow the training progress with TensorBoard by running the following command in a new terminal:

```bash
tensorboard --logdir nnUNetBundle/logs
```