diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py index 72c7b53506..c378a52f78 100644 --- a/tests/test_cross_validation.py +++ b/tests/test_cross_validation.py @@ -11,12 +11,11 @@ import os import unittest -from urllib.error import ContentTooShortError, HTTPError from monai.apps import CrossValidation, DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils.enums import PostFix -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestCrossValidation(unittest.TestCase): @@ -51,14 +50,8 @@ def _test_dataset(dataset): download=True, ) - try: # will start downloading if testing_dir doesn't have the Decathlon files + with skip_if_downloading_fails(): data = cvdataset.get_dataset(folds=0) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index a5d9ce3e27..744dccefaa 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -13,12 +13,11 @@ import shutil import unittest from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError from monai.apps import DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils.enums import PostFix -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestDecathlonDataset(unittest.TestCase): @@ -41,7 +40,7 @@ def _test_dataset(dataset): self.assertTrue(PostFix.meta("image") in dataset[0]) self.assertTupleEqual(dataset[0]["image"].shape, (1, 36, 47, 44)) - try: # will start downloading if testing_dir doesn't have the Decathlon files + with skip_if_downloading_fails(): data = DecathlonDataset( root_dir=testing_dir, task="Task04_Hippocampus", @@ -50,12 +49,6 @@ def _test_dataset(dataset): download=True, copy_cache=False, ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) data = DecathlonDataset( diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index f896d4ae93..435b280022 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -16,7 +16,7 @@ from urllib.error import ContentTooShortError, HTTPError from monai.apps import download_and_extract, download_url, extractall -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestDownloadAndExtract(unittest.TestCase): @@ -27,22 +27,15 @@ def test_actions(self): filepath = Path(testing_dir) / "MedNIST.tar.gz" output_dir = Path(testing_dir) md5_value = "0bc7306e7427e00ad1c5526a6677552d" - try: + with skip_if_downloading_fails(): download_and_extract(url, filepath, output_dir, md5_value) download_and_extract(url, filepath, output_dir, md5_value) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors wrong_md5 = "0" with self.assertLogs(logger="monai.apps", level="ERROR"): try: download_url(url, filepath, wrong_md5) except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) if isinstance(e, RuntimeError): # FIXME: skip MD5 check as current downloading method may fail self.assertTrue(str(e).startswith("md5 check")) @@ -56,7 +49,7 @@ def test_actions(self): @skip_if_quick def test_default(self): with tempfile.TemporaryDirectory() as tmp_dir: - try: + with skip_if_downloading_fails(): # icon.tar.gz https://drive.google.com/file/d/1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn/view?usp=sharing download_and_extract( "https://drive.google.com/uc?id=1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn", @@ -71,12 +64,6 @@ def test_default(self): hash_val="ac6e167ee40803577d98237f2b0241e5", file_type="zip", ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors if __name__ == "__main__": diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index 0ab383fd56..a2a5e30750 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -13,7 +13,6 @@ import unittest from typing import TYPE_CHECKING from unittest import skipUnless -from urllib.error import ContentTooShortError, HTTPError import torch from parameterized import parameterized @@ -27,7 +26,7 @@ get_efficientnet_image_size, ) from monai.utils import optional_import -from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save +from tests.utils import skip_if_downloading_fails, skip_if_quick, test_pretrained_networks, test_script_save if TYPE_CHECKING: import torchvision @@ -251,12 +250,8 @@ class TestEFFICIENTNET(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - try: - # initialize model + with skip_if_downloading_fails(): net = EfficientNetBN(**input_param).to(device) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - return # skipping the tests because of http errors # run inference with random tensor with eval_mode(net): @@ -269,12 +264,8 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_non_default_shapes(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - try: - # initialize model + with skip_if_downloading_fails(): net = EfficientNetBN(**input_param).to(device) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - return # skipping the tests because of http errors # override input shape with different variations num_dims = len(input_shape) - 2 @@ -387,12 +378,8 @@ class TestExtractFeatures(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shapes): device = "cuda" if torch.cuda.is_available() else "cpu" - try: - # initialize model + with skip_if_downloading_fails(): net = EfficientNetBNFeatures(**input_param).to(device) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - return # skipping the tests because of http errors # run inference with random tensor with eval_mode(net): diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 2c0c9e1f2e..3572678b64 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -12,7 +12,6 @@ import os import unittest import warnings -from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -39,7 +38,7 @@ ) from monai.utils import set_determinism from tests.testing_data.integration_answers import test_integration_value -from tests.utils import DistTestCase, TimedCall, skip_if_quick +from tests.utils import DistTestCase, TimedCall, skip_if_downloading_fails, skip_if_quick TEST_DATA_URL = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE" MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d" @@ -186,14 +185,8 @@ def setUp(self): dataset_file = os.path.join(self.data_dir, "MedNIST.tar.gz") if not os.path.exists(data_dir): - try: + with skip_if_downloading_fails(): download_and_extract(TEST_DATA_URL, dataset_file, self.data_dir, MD5_VALUE) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors assert os.path.exists(data_dir) diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index 78c94d4e41..a76808be20 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -24,6 +24,7 @@ from monai.optimizers import LearningRateFinder from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils import optional_import, set_determinism +from tests.utils import skip_if_downloading_fails if TYPE_CHECKING: import matplotlib.pyplot as plt @@ -61,14 +62,15 @@ def setUp(self): def test_lr_finder(self): # 0.001 gives 54 examples - train_ds = MedNISTDataset( - root_dir=self.root_dir, - transform=self.transforms, - section="validation", - val_frac=0.001, - download=True, - num_workers=10, - ) + with skip_if_downloading_fails(): + train_ds = MedNISTDataset( + root_dir=self.root_dir, + transform=self.transforms, + section="validation", + val_frac=0.001, + download=True, + num_workers=10, + ) train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) num_classes = train_ds.get_num_classes() diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 1da3b73de2..e7cc1a60ff 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -13,12 +13,11 @@ import shutil import unittest from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError from monai.apps import MedNISTDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils.enums import PostFix -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick MEDNIST_FULL_DATASET_LENGTH = 58954 @@ -43,16 +42,10 @@ def _test_dataset(dataset): self.assertTrue(PostFix.meta("image") in dataset[0]) self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) - try: # will start downloading if testing_dir doesn't have the MedNIST files + with skip_if_downloading_fails(): data = MedNISTDataset( root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 2cae5969db..cab051e781 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -13,7 +13,6 @@ import tempfile import unittest from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -22,7 +21,7 @@ from monai.apps import RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from monai.apps.mmars import MODEL_DESC from monai.apps.mmars.mmars import _get_val -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick TEST_CASES = [["clara_pt_prostate_mri_segmentation_1"], ["clara_pt_covid19_ct_lesion_segmentation_1"]] TEST_EXTRACT_CASES = [ @@ -105,7 +104,7 @@ class TestMMMARDownload(unittest.TestCase): @parameterized.expand(TEST_CASES) @skip_if_quick def test_download(self, idx): - try: + with skip_if_downloading_fails(): # test model specification cand = get_model_spec(idx) self.assertEqual(cand[RemoteMMARKeys.ID], idx) @@ -116,22 +115,12 @@ def test_download(self, idx): download_mmar(idx, mmar_dir=tmp_dir, progress=False) download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1) # repeated to check caching self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx))) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, HTTPError): - self.assertTrue("500" in str(e)) # http error has the code 500 - return # skipping this test due the network connection errors @parameterized.expand(TEST_EXTRACT_CASES) @skip_if_quick def test_load_ckpt(self, input_args, expected_name, expected_val): - try: + with skip_if_downloading_fails(): output = load_from_mmar(**input_args) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, HTTPError): - self.assertTrue("500" in str(e)) # http error has the code 500 - return self.assertEqual(output.__class__.__name__, expected_name) x = next(output.parameters()) # verify the first element np.testing.assert_allclose(x[0][0].detach().cpu().numpy(), expected_val, rtol=1e-3, atol=1e-3) diff --git a/tests/utils.py b/tests/utils.py index 232c9c9030..0af019f0b0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,10 +21,11 @@ import traceback import unittest import warnings +from contextlib import contextmanager from functools import partial from subprocess import PIPE, Popen from typing import Callable, Optional, Tuple -from urllib.error import ContentTooShortError, HTTPError, URLError +from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -93,17 +94,25 @@ def assert_allclose( np.testing.assert_allclose(actual, desired, *args, **kwargs) -def test_pretrained_networks(network, input_param, device): +@contextmanager +def skip_if_downloading_fails(): try: + yield + except (ContentTooShortError, HTTPError, ConnectionError) as e: + raise unittest.SkipTest(f"error while downloading: {e}") from e + except RuntimeError as rt_e: + if "network issue" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + if "gdown dependency" in str(rt_e): # no gdown installed + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + if "md5 check" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + raise rt_e + + +def test_pretrained_networks(network, input_param, device): + with skip_if_downloading_fails(): return network(**input_param).to(device) - except (URLError, HTTPError) as e: - raise unittest.SkipTest(e) from e - except RuntimeError as r_error: - if "unexpected EOF" in f"{r_error}": # The file might be corrupted. - raise unittest.SkipTest(f"{r_error}") from r_error - if "network issue" in f"{r_error}": # The network is not available. - raise unittest.SkipTest(f"{r_error}") from r_error - raise def test_is_quick(): @@ -651,16 +660,8 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): def download_url_or_skip_test(*args, **kwargs): """``download_url`` and skip the tests if any downloading error occurs.""" - try: + with skip_if_downloading_fails(): download_url(*args, **kwargs) - except (ContentTooShortError, HTTPError) as e: - raise unittest.SkipTest(f"error while downloading: {e}") from e - except RuntimeError as rt_e: - if "network issue" in str(rt_e): - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e - if "gdown dependency" in str(rt_e): # no gdown installed - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e - raise rt_e def query_memory(n=2):