Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class BundleAlgo(Algo):

from monai.apps.auto3dseg import BundleAlgo

data_stats_yaml = "/workspace/datastats.yaml"
algo = BundleAlgo(template_path=../algorithms/templates/segresnet2d/configs)
data_stats_yaml = "../datastats.yaml"
algo = BundleAlgo(template_path="../algorithm_templates")
algo.set_data_stats(data_stats_yaml)
# algo.set_data_src("../data_src.json")
algo.export_to_disk(".", algo_name="segresnet2d_1")
Expand All @@ -69,7 +69,8 @@ def __init__(self, template_path: PathLike):
Create an Algo instance based on the predefined Algo template.

Args:
template_path: path to the root of the algo template.
template_path: path to a folder that contains the algorithm templates.
Please check https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg/algorithm_templates

"""

Expand Down Expand Up @@ -154,7 +155,8 @@ def export_to_disk(self, output_path: str, algo_name: str, **kwargs: Any) -> Non
os.makedirs(self.output_path, exist_ok=True)
if os.path.isdir(self.output_path):
shutil.rmtree(self.output_path)
shutil.copytree(str(self.template_path), self.output_path)
# copy algorithm_templates/<Algo> to the working directory output_path
shutil.copytree(os.path.join(str(self.template_path), self.name), self.output_path)
else:
self.output_path = str(self.template_path)
if kwargs.pop("fill_template", True):
Expand Down Expand Up @@ -342,10 +344,10 @@ def get_output_path(self):

# default algorithms
default_algos = {
"segresnet2d": dict(_target_="segresnet2d.scripts.algo.Segresnet2dAlgo", template_path="segresnet2d"),
"dints": dict(_target_="dints.scripts.algo.DintsAlgo", template_path="dints"),
"swinunetr": dict(_target_="swinunetr.scripts.algo.SwinunetrAlgo", template_path="swinunetr"),
"segresnet": dict(_target_="segresnet.scripts.algo.SegresnetAlgo", template_path="segresnet"),
"segresnet2d": dict(_target_="segresnet2d.scripts.algo.Segresnet2dAlgo"),
"dints": dict(_target_="dints.scripts.algo.DintsAlgo"),
"swinunetr": dict(_target_="swinunetr.scripts.algo.SwinunetrAlgo"),
"segresnet": dict(_target_="segresnet.scripts.algo.SegresnetAlgo"),
}


Expand Down Expand Up @@ -377,7 +379,7 @@ def _download_algos_url(url: str, at_path: str) -> dict[str, dict[str, str]]:

algos_all = deepcopy(default_algos)
for name in algos_all:
algos_all[name]["template_path"] = os.path.join(at_path, algos_all[name]["template_path"])
algos_all[name]["template_path"] = at_path

return algos_all

Expand All @@ -398,9 +400,7 @@ def _copy_algos_folder(folder, at_path):
algos_all = {}
for name in os.listdir(at_path):
if os.path.exists(os.path.join(folder, name, "scripts", "algo.py")):
algos_all[name] = dict(
_target_=f"{name}.scripts.algo.{name.capitalize()}Algo", template_path=os.path.join(at_path, name)
)
algos_all[name] = dict(_target_=f"{name}.scripts.algo.{name.capitalize()}Algo", template_path=at_path)
logger.info(f"Copying template: {name} -- {algos_all[name]}")
if not algos_all:
raise ValueError(f"Unable to find any algos in {folder}")
Expand Down Expand Up @@ -463,7 +463,7 @@ def __init__(
self.algos: Any = []
if isinstance(algos, dict):
for algo_name, algo_params in sorted(algos.items()):
template_path = os.path.dirname(algo_params.get("template_path", "."))
template_path = algo_params.get("template_path", ".")
if len(template_path) > 0 and template_path not in sys.path:
sys.path.append(template_path)

Expand All @@ -486,7 +486,7 @@ def __init__(
raise ValueError("Unexpected error algos is not a dict")

self.data_stats_filename = data_stats_filename
self.data_src_cfg_filename = data_src_cfg_name
self.data_src_cfg_name = data_src_cfg_name
self.history: list[dict] = []

def set_data_stats(self, data_stats_filename: str) -> None:
Expand All @@ -502,18 +502,18 @@ def get_data_stats(self):
"""Get the filename of the data stats"""
return self.data_stats_filename

def set_data_src(self, data_src_cfg_filename):
def set_data_src(self, data_src_cfg_name):
"""
Set the data source filename

Args:
data_src_cfg_filename: filename of data_source file
data_src_cfg_name: filename of data_source file
"""
self.data_src_cfg_filename = data_src_cfg_filename
self.data_src_cfg_name = data_src_cfg_name

def get_data_src(self):
"""Get the data source filename"""
return self.data_src_cfg_filename
return self.data_src_cfg_name

def get_history(self) -> list:
"""Get the history of the bundleAlgo object with their names/identifiers"""
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsRe
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
typeguard<3 # https://github.com/microsoft/nni/issues/5457
filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
3 changes: 3 additions & 0 deletions tests/test_auto3dseg_bundlegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import os
import shutil
import sys
import tempfile
import unittest

Expand Down Expand Up @@ -126,6 +127,7 @@ def test_move_bundle_gen_folder(self) -> None:
data_src_cfg = os.path.join(work_dir, "data_src_cfg.yaml")
ConfigParser.export_config_file(data_src, data_src_cfg)

sys_path = sys.path.copy()
with skip_if_downloading_fails():
bundle_generator = BundleGen(
algo_path=work_dir,
Expand All @@ -138,6 +140,7 @@ def test_move_bundle_gen_folder(self) -> None:
history_before = bundle_generator.get_history()
export_bundle_algo_history(history_before)

sys.path = sys_path # prevent the import_bundle_algo_history from using the path "work_dir/algorithm_templates"
tempfile.TemporaryDirectory()
work_dir_new = os.path.join(test_path, "workdir_2")
shutil.move(work_dir, work_dir_new)
Expand Down