## HPO tutorial: learning rate grid search on prostate segmentation using Optuna.
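#### Steps 3-5 can be skipped if the dataset is downloaded and the bundles are already generated (steps 1-2 define imports and paths and must always be run).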

### 1. Import libraries for HPO and pipelines

In [None]:
import json
import os
import optuna
import yaml

from functools import partial

from monai.apps import download_and_extract
from monai.apps.auto3dseg import BundleGen, DataAnalyzer, OptunaGen
from monai.apps.auto3dseg.utils import export_bundle_algo_history, import_bundle_algo_history
from monai.bundle.config_parser import ConfigParser

### 2. Define experiment file paths

In [None]:
# User-created files: the cross-validation datalist (JSON) and the task description (YAML)
datalist_filename = './msd_task05_prostate_folds.json'
input_yaml = './input.yaml'

# Dataset paths
data_root = "./"
msd_task = "Task05_Prostate"
dataroot = os.path.join(data_root, msd_task)

# Experiment setup
test_path = "./"
work_dir = os.path.join(test_path, "workdir")
# Output folder handed to every Optuna trial in step 8
optuna_dir = './optuna_learningrate_grid'
# Data statistics written by DataAnalyzer in step 5
da_output_yaml = os.path.join(work_dir, "datastats.yaml")
# exist_ok makes this a no-op when work_dir already exists and avoids a check-then-create race
os.makedirs(work_dir, exist_ok=True)

# Algorithm selected to do HPO. Refer to bundle history for the mapping between
# algorithm name and index
selected_algorithm_index = 0

### 3. Download one of the MSD datasets

In [None]:

# Fetch and unpack the MSD task archive unless the dataset directory is already present.
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + msd_task + ".tar"
compressed_file = os.path.join(data_root, msd_task + ".tar")
if not os.path.exists(dataroot):
    # Do NOT pre-create `dataroot` here. An interrupted or failed download would leave
    # an empty directory behind, and every later run would then skip the download.
    # Only the download/extraction target folder is created. The archive is expected
    # to unpack into `dataroot` itself (TODO confirm for other MSD tasks).
    os.makedirs(data_root, exist_ok=True)
    download_and_extract(resource, compressed_file, data_root)

### 4. Generate the input YAML and the data-fold JSON (users should generate their own for other datasets)

In [None]:
# Five cross-validation folds of six Task05 prostate cases each. Every case
# contributes an image/label pair sharing the same zero-padded file name.
fold_case_ids = {
    0: [16, 4, 20, 43, 6, 14],
    1: [41, 34, 38, 10, 2, 24],
    2: [47, 28, 0, 42, 21, 17],
    3: [40, 31, 7, 35, 44, 39],
    4: [1, 13, 46, 25, 29, 37],
}
datalist = {
    "training": [
        {
            "fold": fold,
            "image": f"imagesTr/prostate_{case_id:02d}.nii.gz",
            "label": f"labelsTr/prostate_{case_id:02d}.nii.gz",
        }
        for fold, case_ids in fold_case_ids.items()
        for case_id in case_ids
    ]
}

# Task description handed to BundleGen (as data_src_cfg_name) in step 5.
input_dict = dict(
    name="Task05_Prostate",
    task="segmentation",
    modality="MRI",
    datalist=datalist_filename,
    dataroot=dataroot,
    multigpu=True,
    class_names=["val_acc_pz", "val_acc_tz"],
)

# Persist both configuration files so the later steps can load them from disk.
with open(datalist_filename, 'w') as datalist_file:
    json.dump(datalist, datalist_file, indent=4)
with open(input_yaml, 'w') as input_file:
    yaml.dump(input_dict, input_file)


### 5. Create Bundle Generators


In [None]:

# Reload the task description and the fold datalist written in step 4.
cfg = ConfigParser.load_config_file(input_yaml)
datalist = ConfigParser.load_config_file(datalist_filename)

# Data analysis: only computed when no statistics file exists yet.
if not os.path.exists(da_output_yaml):
    da = DataAnalyzer(datalist, dataroot, output_path=da_output_yaml)
    da.get_all_case_stats()

# Algorithm generation: write bundles into work_dir, configured for 5 folds,
# then export the generation history.
bundle_generator = BundleGen(
    data_src_cfg_name=input_yaml,
    data_stats_filename=da_output_yaml,
    algo_path=work_dir,
)
bundle_generator.generate(work_dir, num_fold=5)

history = bundle_generator.get_history()
export_bundle_algo_history(history)

### 6. Create Algo object from bundle_generator history

In [None]:
# The history comes from bundle_generator (step 5). When step 5 was not run in this
# session, or returned nothing, it is rebuilt from the bundles saved on disk.
try:
    history = bundle_generator.get_history()
except NameError:
    # bundle_generator only exists if step 5 ran in the current kernel session
    history = []
if not history:
    history = import_bundle_algo_history(work_dir, only_trained=False)

# selected_algorithm_index picks an entry from the history list. Each entry is a dict
# keyed by algorithm name (presumably a single key per entry; TODO confirm against the
# installed MONAI version), so the name is taken as the entry's first key rather than
# indexing the keys with selected_algorithm_index again.
algo_dict = history[selected_algorithm_index]
algo_name = next(iter(algo_dict))
algo = algo_dict[algo_name]

# "override_params" is used to update algorithm hyperparameters
# like num_epochs, which are not in the HPO search space. We set num_epochs=2
# (and small iteration counts) to shorten the training time as an example
override_param = {
    "num_iterations": 8,
    "num_iterations_per_validation": 4,
    "num_images_per_batch": 2,
    "num_epochs": 2,
    "num_warmup_iterations": 4,
}



### 7. Create an Optuna generator class and override the get_hyperparameters() function

In [None]:
class OptunaGenLearningRate(OptunaGen):
    """Optuna generator whose search space is the learning rate only."""

    def get_hyperparameters(self):
        """Return the hyperparameters proposed by the current Optuna trial.

        Returns:
            dict: ``{"learning_rate": float}`` drawn from ``self.trial``. The
            range [1e-5, 0.1] must enclose every grid value used in step 8.
            ``log=True`` makes non-grid samplers (e.g. TPE) explore the range
            evenly across orders of magnitude; a GridSampler still returns the
            grid values unchanged.
        """
        return {'learning_rate': self.trial.suggest_float("learning_rate", 0.00001, 0.1, log=True)}


# override_param holds fixed (non-searched) settings applied to every trial.
optuna_gen = OptunaGenLearningRate(algo=algo, params=override_param)

### 8. Run Optuna optimization (with grid search)

In [None]:
# Grid over the learning rate. Every value must lie inside the suggest_float range
# used by OptunaGenLearningRate.get_hyperparameters.
search_space = {'learning_rate': [0.0001, 0.001, 0.01, 0.1]}
# GridSampler visits grid points in random order, so one trial per grid point is needed
# to cover the whole grid. Lower n_trials to shorten the demo at the cost of coverage.
n_trials = len(search_space['learning_rate'])
study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space), direction='maximize')
objective = partial(optuna_gen, obj_filename=optuna_gen.get_obj_filename(), output_folder=optuna_dir)
study.optimize(objective, n_trials=n_trials)
print(f"Best value: {study.best_value} (params: {study.best_params})\n")