Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ jobs:
python -m pip list
python setup.py develop # test no compile installation
shell: bash
- if: runner.os != 'windows'
name: Run compiled (${{ runner.os }})
- name: Run compiled (${{ runner.os }})
run: |
python setup.py develop --uninstall
BUILD_MONAI=1 python setup.py develop # compile the cpp extensions
Expand Down
17 changes: 13 additions & 4 deletions tests/test_auto3dseg_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from monai.apps.auto3dseg import AlgoEnsembleBestByFold, AlgoEnsembleBestN, AlgoEnsembleBuilder, BundleGen, DataAnalyzer
from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
from monai.utils import optional_import
from monai.utils import optional_import, set_determinism
from monai.utils.enums import AlgoEnsembleKeys
from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick

Expand Down Expand Up @@ -68,6 +68,7 @@
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestEnsembleBuilder(unittest.TestCase):
def setUp(self) -> None:
set_determinism(0)
self.test_dir = tempfile.TemporaryDirectory()

@skip_if_no_cuda
Expand Down Expand Up @@ -126,12 +127,19 @@ def test_ensemble(self) -> None:

for h in history:
self.assertEqual(len(h.keys()), 1, "each record should have one model")
for _, algo in h.items():
algo.train(train_param)
for name, algo in h.items():
_train_param = train_param.copy()
if name.startswith("segresnet"):
_train_param["network#init_filters"] = 8
_train_param["pretrained_ckpt_name"] = ""
elif name.startswith("swinunetr"):
_train_param["network#feature_size"] = 12
algo.train(_train_param)

builder = AlgoEnsembleBuilder(history, data_src_cfg)
builder.set_ensemble_method(AlgoEnsembleBestN(n_best=2))
builder.set_ensemble_method(AlgoEnsembleBestN(n_best=1))
ensemble = builder.get_ensemble()
pred_param["network#init_filter"] = 8 # segresnet
preds = ensemble(pred_param)
self.assertTupleEqual(preds[0].shape, (2, 24, 24, 24))

Expand All @@ -141,6 +149,7 @@ def test_ensemble(self) -> None:
print(algo[AlgoEnsembleKeys.ID])

def tearDown(self) -> None:
set_determinism(None)
self.test_dir.cleanup()


Expand Down