Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions monai/apps/nnunet/nnunetv2_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import glob
import os
import re
import shlex
import subprocess
from typing import Any
Expand All @@ -34,6 +35,8 @@

__all__ = ["nnUNetV2Runner"]

# Accepted forms of `dataset_name_or_id`: the literal "Dataset" followed by exactly three digits, or a run of one
# or more digits. The pattern carries no anchors of its own, so it is applied with `re.fullmatch` to check the
# whole string rather than a prefix.
DATASET_ID_FORMAT = r"Dataset[0-9]{3}|[0-9]+" # regex format for a valid nnUNet dataset name


class nnUNetV2Runner: # noqa: N801
"""
Expand Down Expand Up @@ -195,6 +198,13 @@ def __init__(

# dataset_name_or_id has to be a string
self.dataset_name_or_id = str(self.input_info.pop("dataset_name_or_id", 1))
self.dataset_name: str | None = None

# ensure the dataset name is a single identifier/number; this prevents code injection when composing commands
if re.fullmatch(DATASET_ID_FORMAT, self.dataset_name_or_id) is None:
raise ValueError(
f"Value for dataset_name_or_id `{self.dataset_name_or_id}` not a valid dataset name or ID."
)

try:
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
Expand Down Expand Up @@ -239,7 +249,7 @@ def convert_dataset(self):

from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name

self.dataset_name = maybe_convert_to_dataset_name(int(self.dataset_name_or_id))
self.dataset_name = maybe_convert_to_dataset_name(self.dataset_name_or_id)

datalist_json = ConfigParser.load_config_file(self.input_info.pop("datalist"))

Expand Down Expand Up @@ -548,7 +558,7 @@ def train_single_model_command(
Raises:
ValueError: If gpu_id is an empty tuple or list.
"""
env = os.environ.copy()
env: dict[str, str] = os.environ.copy()
device_setting: str = "0"
num_gpus = 1
if isinstance(gpu_id, str):
Expand All @@ -574,22 +584,25 @@ def train_single_model_command(

cmd = [
"nnUNetv2_train",
f"{self.dataset_name_or_id}",
f"{config}",
f"{fold}",
self.dataset_name_or_id,
config,
fold,
"-tr",
f"{self.trainer_class_name}",
self.trainer_class_name,
"-num_gpus",
f"{num_gpus}",
num_gpus,
]

if self.export_validation_probabilities:
cmd.append("--npz")

for _key, _value in kwargs.items():
if _key == "p" or _key == "pretrained_weights":
cmd.extend([f"-{_key}", f"{_value}"])
else:
cmd.extend([f"--{_key}", f"{_value}"])
return cmd, env
prefix = "-" if _key in {"p", "pretrained_weights"} else "--"
cmd += [f"{prefix}{_key}", str(_value)]

cmd_str: list[str] = [str(c) for c in cmd]

return cmd_str, env

def train(
self,
Expand Down Expand Up @@ -641,7 +654,14 @@ def train_parallel_cmd(
None (all available GPUs).
kwargs: this optional parameter allows you to specify additional arguments defined in the
``train_single_model`` method.

Raises:
ValueError: if self.dataset_name is None; it must have a value, i.e. when using an existing dataset or after creating one.
"""

if self.dataset_name is None:
raise ValueError(f"A valid dataset name must be given in {self.dataset_name=}.")

# unpack compressed files
folder_names = []
for root, _, files in os.walk(os.path.join(self.nnunet_preprocessed, self.dataset_name)):
Expand Down Expand Up @@ -696,7 +716,14 @@ def train_parallel(
None (all available GPUs).
kwargs: this optional parameter allows you to specify additional arguments defined in the
``train_single_model`` method.

Raises:
ValueError: if self.dataset_name is None; it must have a value, i.e. when using an existing dataset or after creating one.
"""

if self.dataset_name is None:
raise ValueError(f"A valid dataset name must be given in {self.dataset_name=}.")

all_cmds = self.train_parallel_cmd(configs=configs, gpu_id_for_all=gpu_id_for_all, **kwargs)
for s, cmds in enumerate(all_cmds):
for gpu_id, gpu_cmd in cmds.items():
Expand Down Expand Up @@ -908,7 +935,14 @@ def predict_ensemble_postprocessing(
run_postprocessing: whether to conduct post-processing
kwargs: this optional parameter allows you to specify additional arguments defined in the
``predict`` method.

Raises:
ValueError: if self.dataset_name is None; it must have a value, i.e. when using an existing dataset or after creating one.
"""

if self.dataset_name is None:
raise ValueError(f"A valid dataset name must be given in {self.dataset_name=}.")

from nnunetv2.ensembling.ensemble import ensemble_folders
from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder
from nnunetv2.utilities.file_path_utilities import get_output_folder
Expand Down
75 changes: 75 additions & 0 deletions tests/integration/test_integration_nnunetv2_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@

from __future__ import annotations

import logging
import os
import tempfile
import unittest
from textwrap import dedent

import nibabel as nib
import numpy as np

import monai.apps.nnunet.nnunetv2_runner
from monai.apps.nnunet import nnUNetV2Runner
from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
Expand All @@ -27,6 +30,8 @@
_, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter")
_, has_nnunet = optional_import("nnunetv2")

monai.apps.nnunet.nnunetv2_runner.logger.setLevel(logging.ERROR) # suppress warning logging to clean up test output

sim_datalist: dict[str, list[dict]] = {
"testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}],
"training": [
Expand Down Expand Up @@ -91,5 +96,75 @@ def tearDown(self) -> None:
self.test_dir.cleanup()


@skip_if_quick
@unittest.skipIf(not has_nnunet, "no nnunetv2")
class TestnnUNetV2RunnerSecurity(unittest.TestCase):
    """
    Check that nnUNetV2Runner validates the ``dataset_name_or_id`` value read from its input config file.
    """

    # settings shared by every config file written in setUp, only ``dataset_name_or_id`` differs between them
    common_settings = dedent(
        """\
        dataroot: ./data
        datalist: ./lists/task4.json
        work_dir: ./work
        nnunet_raw: ./nnUNet_raw
        nnunet_preprocessed: ./nnUNet_preprocessed
        nnunet_results: ./nnUNet_results
        """
    )

    def setUp(self) -> None:
        self.test_dir = tempfile.TemporaryDirectory()

        # configs with well-formed dataset identifiers: a "Dataset###" name and a bare integer ID
        self.good_yml1 = self._write_config("good1.yml", "Dataset123")
        self.good_yml2 = self._write_config("good2.yml", "123")

        # config with a code-injecting dataset name, single-quoted so YAML reads the whole value as one string
        self.inject_yml = self._write_config("test.yml", '\'4 & echo "This is exploited" > "./test.txt" & rem\'')

    def _write_config(self, filename: str, dataset_name_or_id: str) -> str:
        """
        Write a runner config file into the temporary test directory.

        Args:
            filename: name of the YAML file to create.
            dataset_name_or_id: raw YAML text to place after the ``dataset_name_or_id:`` key.

        Returns:
            Path of the file that was written.
        """
        path = os.path.join(self.test_dir.name, filename)

        with open(path, "w", encoding="utf-8") as o:
            o.write(f"dataset_name_or_id: {dataset_name_or_id}\n{self.common_settings}")

        return path

    def test_nnunetv2runner_good_dataset_name(self) -> None:
        """
        Test that a dataset name given as an int or as "Dataset###" is accepted when constructing the runner.
        """
        for ds in [self.good_yml1, self.good_yml2]:
            with self.subTest(f"Testing {os.path.basename(ds)}"):
                nnUNetV2Runner(input_config=ds, trainer_class_name="nnUNetTrainer")

    def test_nnunetv2runner_bad_dataset_name(self) -> None:
        """
        Test that a dataset name which is neither an int nor "Dataset###", here a string carrying injected shell
        commands, makes constructing the runner raise ValueError.
        """
        with self.assertRaises(ValueError):
            nnUNetV2Runner(input_config=self.inject_yml, trainer_class_name="nnUNetTrainer")

    def tearDown(self) -> None:
        self.test_dir.cleanup()


# run the test cases in this file only when it is executed as a script, not when it is imported
if __name__ == "__main__":
    unittest.main()
Loading