# MONAI Auto3DSeg Hyper-parameter Optimization with NNI

**Auto3DSeg** supports hyper-parameter optimization (HPO) with the `NNI` and `Optuna` packages.
Please check the [Optuna Notebook](hpo_optuna.ipynb) if you want to use **Auto3DSeg** with `Optuna` HPO.

This notebook shows how to use NNI to run a grid-search HPO over the learning rate for hippocampus segmentation.
To execute the HPO steps of this tutorial, install `nni` first (for example, with `pip install nni`).
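Before running the HPO cells, you may want to confirm that `nni` and `monai` can be imported. The sketch below uses only the standard library's `importlib.util.find_spec`; the helper name `missing_packages` is hypothetical and not part of MONAI or NNI.

```python
import importlib.util


def missing_packages(names):
    """Return the names from `names` whose top-level module cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# "monai" and "nni" are the packages this tutorial relies on; adjust as needed.
missing = missing_packages(["monai", "nni"])
if missing:
    print("Missing packages, please install:", ", ".join(missing))
else:
    print("All required packages are available.")
```

`find_spec` only locates the module without importing it, so the check is cheap even for large packages.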

Note: if you have already run other notebooks under `auto3dseg`, for example:
- `auto_runner.ipynb`
- `auto3dseg_autorunner_ref_api.ipynb`
- `auto3dseg_hello_world.ipynb`
- `hpo_optuna.ipynb`

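You may have already generated the algorithm templates in MONAI bundle format (look for them in the working directory).

If the bundles are already generated, feel free to skip the data analysis and bundle generation cell under "Create Bundle Generators": step 6 falls back to loading the bundles from the working directory. The earlier cells are still needed, since they define the paths and `input_cfg` used later.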

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, nni, tqdm, cucim, yaml, optuna]"

## Import libraries for HPO and pipelines

In [None]:
import os
import torch
import yaml

import tempfile

from monai.apps import download_and_extract
from monai.apps.auto3dseg import BundleGen, DataAnalyzer, NNIGen
from monai.apps.auto3dseg.utils import export_bundle_algo_history, import_bundle_algo_history
from monai.bundle.config_parser import ConfigParser

## Download dataset

In [None]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

msd_task = "Task04_Hippocampus"
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + msd_task + ".tar"

compressed_file = os.path.join(root_dir, msd_task + ".tar")
dataroot = os.path.join(root_dir, msd_task)
if not os.path.exists(dataroot):
    download_and_extract(resource, compressed_file, root_dir)

datalist_file = os.path.join("..", "tasks", "msd", msd_task, "msd_" + msd_task.lower() + "_folds.json")
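To see how the cases in `datalist_file` are spread over folds (for instance, to pick the number of training images used in the iteration arithmetic of step 6), a small helper can count them. This is a sketch that assumes the datalist JSON has a `training` list whose entries carry a `fold` key, the layout Auto3DSeg datalists use; the helper names and the toy datalist are hypothetical.

```python
from collections import Counter


def count_cases_per_fold(datalist):
    """Count training cases per fold in an Auto3DSeg-style datalist dict."""
    return Counter(item["fold"] for item in datalist["training"])


def split_by_fold(datalist, val_fold=0):
    """Split training entries into (train, val) lists, holding out one fold."""
    train = [c for c in datalist["training"] if c["fold"] != val_fold]
    val = [c for c in datalist["training"] if c["fold"] == val_fold]
    return train, val


# Hypothetical toy datalist: 10 cases spread evenly across 5 folds.
toy = {
    "training": [
        {"image": f"imagesTr/case_{i}.nii.gz", "label": f"labelsTr/case_{i}.nii.gz", "fold": i % 5}
        for i in range(10)
    ]
}
print(count_cases_per_fold(toy))
train, val = split_by_fold(toy, val_fold=0)
print(len(train), len(val))  # 8 2
```

For the real file, load it first with `json.load` (e.g. `with open(datalist_file) as f: datalist = json.load(f)`) and pass the resulting dict.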

## Define experiment file paths

In [None]:
# User created files
nni_yaml = './nni_config.yaml'

# Experiment setup
test_path = "./"
work_dir = os.path.join(test_path, "hpo_nni_work_dir")
datastats_file = os.path.join(work_dir, "datastats.yaml")
os.makedirs(work_dir, exist_ok=True)

## Prepare an input yaml

In [None]:
input_cfg = {
    "name": msd_task,  # optional, it is only for your own record
    "task": "segmentation",  # optional, it is only for your own record
    "modality": "MRI",  # required
    "datalist": datalist_file,  # required
    "dataroot": dataroot,  # required
}
input = './input.yaml'
ConfigParser.export_config_file(input_cfg, input)
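Only three fields of the input config are marked as required in the cell above. A quick sanity check before running the data analysis can catch a missing path early. This is a minimal sketch: `REQUIRED_KEYS` is taken from the comments in that cell, not from MONAI's own validation logic, and `check_input_cfg` is a hypothetical helper.

```python
# Keys marked "required" in the input_cfg cell; taken from those comments,
# not from MONAI's own validation code.
REQUIRED_KEYS = ("modality", "datalist", "dataroot")


def check_input_cfg(cfg):
    """Raise ValueError if a required input field is missing or empty; else return cfg."""
    missing = [key for key in REQUIRED_KEYS if not cfg.get(key)]
    if missing:
        raise ValueError(f"input config is missing required keys: {missing}")
    return cfg


# Toy config for illustration; in the notebook you would pass input_cfg instead.
check_input_cfg({"modality": "MRI", "datalist": "folds.json", "dataroot": "/data"})
```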

## Create Bundle Generators


In [None]:
if not os.path.exists(datastats_file):
    da = DataAnalyzer(datalist_file, dataroot, output_path=datastats_file)
    da.get_all_case_stats()

# algorithm generation
bundle_generator = BundleGen(
    algo_path=work_dir,
    data_stats_filename=datastats_file,
    data_src_cfg_name=input,
)

bundle_generator.generate(work_dir, num_fold=1)
history = bundle_generator.get_history()
export_bundle_algo_history(history)

## 6. Create Algo object from bundle_generator history

In [None]:
# Index of the algorithm to tune with HPO. The index follows the order of the
# bundle history; in this example, 0 is SegResNet2D
selected_algorithm_index = 0

# Get the history from bundle_generator; if it is unavailable (e.g., the bundle
# generation cell was skipped), read the bundles saved on disk instead
try:
    history = bundle_generator.get_history()
    assert len(history) > 0
except Exception:
    history = import_bundle_algo_history(work_dir, only_trained=False)

algo_dict = history[selected_algorithm_index]
# each history entry maps a single algorithm name to its Algo object
algo_name = list(algo_dict.keys())[0]
algo = algo_dict[algo_name]
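Selecting by index depends on the order of the history. If you prefer to select the algorithm by name, a lookup like the one below can be used. It is a sketch that assumes each history entry is a single-key dict `{name: Algo}`, as the cell above implies; names such as `segresnet2d_0` follow the pattern visible in the example command in section 8. `find_algo` is a hypothetical helper.

```python
def find_algo(history, algo_name):
    """Return (name, algo) for the history entry whose key equals algo_name.

    Assumes each history entry is a single-key dict {name: Algo}.
    """
    for entry in history:
        if algo_name in entry:
            return algo_name, entry[algo_name]
    available = [name for entry in history for name in entry]
    raise KeyError(f"{algo_name!r} not found; available: {available}")


# Toy history with placeholder strings instead of real Algo objects.
toy_history = [{"segresnet2d_0": "algo A"}, {"dints_0": "algo B"}]
print(find_algo(toy_history, "dints_0"))
```

In the notebook this would read `algo_name, algo = find_algo(history, "segresnet2d_0")`; the error message lists the available names if the lookup fails.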

In [None]:
# `override_param` updates algorithm hyperparameters that are not in the HPO
# search space, such as num_epochs. We set num_epochs=2 here only to keep the
# example's training time short.

max_epochs = 2

# safeguard to ensure max_epochs is greater than or equal to 2
max_epochs = max(max_epochs, 2)

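# fall back to 1 GPU so the iteration arithmetic below never divides by zero on CPU-only machines
num_gpus = 1 if "multigpu" in input_cfg and not input_cfg["multigpu"] else max(torch.cuda.device_count(), 1)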

num_epoch = max_epochs
num_images_per_batch = 2
n_data = 24  # total is 30 images, hold out one set (6 images) for cross fold val.
n_iter = int(num_epoch * n_data / num_images_per_batch / num_gpus)
n_iter_val = int(n_iter / 2)

override_param = {
    "num_iterations": n_iter,
    "num_iterations_per_validation": n_iter_val,
    "num_images_per_batch": num_images_per_batch,
    "num_epochs": num_epoch,
    "num_warmup_iterations": n_iter_val,
}
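The iteration counts follow from the epoch budget: each iteration consumes `num_images_per_batch` images on each GPU, so one pass over `n_data` images takes `n_data / (num_images_per_batch * num_gpus)` iterations, and validation runs every half of the total. The function below reproduces that arithmetic so you can check the numbers before training; `iteration_budget` is a hypothetical name, and the `max(num_gpus, 1)` guard is an addition of this sketch.

```python
def iteration_budget(num_epochs, n_data, num_images_per_batch, num_gpus):
    """Return (n_iter, n_iter_val) using the same formula as the cell above."""
    num_gpus = max(num_gpus, 1)  # guard added in this sketch for CPU-only machines
    n_iter = int(num_epochs * n_data / num_images_per_batch / num_gpus)
    n_iter_val = int(n_iter / 2)  # validation interval: half of the total iterations
    return n_iter, n_iter_val


print(iteration_budget(2, 24, 2, 1))  # (24, 12)
print(iteration_budget(2, 24, 2, 2))  # (12, 6)
```

With the values above (2 epochs, 24 images, batch size 2, one GPU) the run lasts 24 iterations and validates every 12; a second GPU halves both numbers.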

In [None]:
nni_gen = NNIGen(algo=algo, params=override_param)

## 7. Create your NNI config
Refer to the [NNI documentation](https://nni.readthedocs.io/en/stable/) for details on each field.

In [None]:
nni_config = {
    "experimentName": msd_task + "_lr",
    "searchSpace": {
        "learning_rate": {
            "_type": "choice",
            "_value": [0.0001, 0.001, 0.01, 0.1]
        }
    },
    "trialCommand": None,  # placeholder, replaced by the command NNIGen prints (section 8)
    "trialCodeDirectory": ".",
    "trialGpuNumber": 1,
    "trialConcurrency": 2,
    "maxTrialNumber": 10,  # upper bound; this grid only has 4 points
    "maxExperimentDuration": "1h",
    "tuner": {"name": "GridSearch"},
    "trainingService": {
        "platform": "local",
        "useActiveGpu": True,
    },
}
with open(nni_yaml, 'w') as f:
    yaml.dump(nni_config, f)
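A grid search over `choice` parameters tries every combination of the listed values. The sketch below enumerates those combinations for an NNI-style search space, to show why this experiment should need only four trials even though `maxTrialNumber` is 10. It is a simplified re-implementation for illustration, not NNI's GridSearch tuner, and it handles only `choice` entries; `grid_points` is a hypothetical helper.

```python
import itertools


def grid_points(search_space):
    """Enumerate every combination of `choice` values in an NNI-style search space."""
    names = list(search_space)
    for name in names:
        if search_space[name]["_type"] != "choice":
            raise ValueError(f"{name}: only '_type': 'choice' is handled in this sketch")
    value_lists = [search_space[name]["_value"] for name in names]
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]


space = {"learning_rate": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]}}
points = grid_points(space)
print(len(points))  # 4
```

Adding a second `choice` parameter multiplies the grid size, so the number of trials grows quickly with each tuned hyperparameter.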

## 8. Run NNI from a terminal
### Step 1: copy the `trialCommand` printed during `NNIGen` initialization, e.g.
```
python -m monai.apps.auto3dseg NNIGen run_algo  ./workdir/segresnet2d_0/algo_object.pkl {result_dir}
```
Replace {result_dir} with a folder path to save HPO experiments.
### Step 2: paste that command, with `{result_dir}` filled in, as the `trialCommand` value in `nni_config.yaml`
### Step 3: start the NNI experiment from a terminal with
```
nnictl create --config ./nni_config.yaml
```

In short: paste the `trialCommand` printed during `NNIGen` initialization into `nni_config.yaml`, then launch the experiment from a terminal.
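Instead of editing `nni_config.yaml` by hand in Step 2, you could set `trialCommand` in Python before dumping the config. This is a minimal sketch: it rebuilds the command from the pattern shown in Step 1, so if the command `NNIGen` prints in your MONAI version differs, copy that one instead. `with_trial_command` is a hypothetical helper, and `shlex.join` requires Python 3.8 or later.

```python
import copy
import shlex


def with_trial_command(nni_config, algo_pkl, result_dir):
    """Return a copy of nni_config whose trialCommand runs one NNIGen trial."""
    # Command pattern copied from the Step 1 example; shlex.join quotes paths with spaces.
    cmd = shlex.join(
        ["python", "-m", "monai.apps.auto3dseg", "NNIGen", "run_algo", algo_pkl, result_dir]
    )
    updated = copy.deepcopy(nni_config)  # leave the original dict untouched
    updated["trialCommand"] = cmd
    return updated
```

In the notebook you would pass `nni_config`, the `algo_object.pkl` path printed by `NNIGen`, and your result folder, then write the returned dict with `yaml.dump` as in step 7.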

## 9. Example Results
To run the experiments for longer, we changed `override_param` to `{'num_iterations': 6000, 'num_iterations_per_validation': 600}`.
The results shown in the NNI web UI are below. The best learning rate for SegResNet2D (`selected_algorithm_index=0`) is 0.1, which achieves a Dice score of 0.735.

![](../figures/nni_image0.png)
![](../figures/nni_image1.png)
