From 56076e03e22cb305c4e515467b53b74e646b7077 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 27 Jan 2022 14:43:33 +0000 Subject: [PATCH 1/2] fixes test Signed-off-by: Wenqi Li --- tests/test_download_and_extract.py | 19 +++---------------- tests/test_lr_finder.py | 18 ++++++++++-------- tests/utils.py | 17 +++++++++++++---- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index f896d4ae93..2a63a6b44e 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -16,7 +16,7 @@ from urllib.error import ContentTooShortError, HTTPError from monai.apps import download_and_extract, download_url, extractall -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fail, skip_if_quick class TestDownloadAndExtract(unittest.TestCase): @@ -27,22 +27,15 @@ def test_actions(self): filepath = Path(testing_dir) / "MedNIST.tar.gz" output_dir = Path(testing_dir) md5_value = "0bc7306e7427e00ad1c5526a6677552d" - try: + with skip_if_downloading_fail(): download_and_extract(url, filepath, output_dir, md5_value) download_and_extract(url, filepath, output_dir, md5_value) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors wrong_md5 = "0" with self.assertLogs(logger="monai.apps", level="ERROR"): try: download_url(url, filepath, wrong_md5) except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) if isinstance(e, RuntimeError): # FIXME: skip MD5 check as current downloading method may fail self.assertTrue(str(e).startswith("md5 check")) @@ -56,7 +49,7 @@ def test_actions(self): @skip_if_quick def test_default(self): with tempfile.TemporaryDirectory() as tmp_dir: - try: + with skip_if_downloading_fail(): # icon.tar.gz https://drive.google.com/file/d/1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn/view?usp=sharing download_and_extract( "https://drive.google.com/uc?id=1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn", @@ -71,12 +64,6 @@ def test_default(self): hash_val="ac6e167ee40803577d98237f2b0241e5", file_type="zip", ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors if __name__ == "__main__": diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index 78c94d4e41..b79f1d697f 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -24,6 +24,7 @@ from monai.optimizers import LearningRateFinder from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils import optional_import, set_determinism +from tests.utils import skip_if_downloading_fail if TYPE_CHECKING: import matplotlib.pyplot as plt @@ -61,14 +62,15 @@ def setUp(self): def test_lr_finder(self): # 0.001 gives 54 examples - train_ds = MedNISTDataset( - root_dir=self.root_dir, - transform=self.transforms, - section="validation", - val_frac=0.001, - download=True, - num_workers=10, - ) + with skip_if_downloading_fail(): + train_ds = MedNISTDataset( + root_dir=self.root_dir, + transform=self.transforms, + section="validation", + val_frac=0.001, + download=True, + num_workers=10, + ) train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) num_classes = train_ds.get_num_classes() diff --git a/tests/utils.py b/tests/utils.py index 232c9c9030..19510cd908 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,6 +21,7 @@ import traceback import unittest import warnings +from contextlib import contextmanager from functools import partial from subprocess import PIPE, Popen from typing import Callable, Optional, Tuple @@ -649,20 +650,28 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): ) -def download_url_or_skip_test(*args, **kwargs): - """``download_url`` and skip the tests if any downloading error occurs.""" +@contextmanager +def skip_if_downloading_fail(): try: - download_url(*args, **kwargs) - except (ContentTooShortError, HTTPError) as e: + yield + except (ContentTooShortError, HTTPError, ConnectionError) as e: raise unittest.SkipTest(f"error while downloading: {e}") from e except RuntimeError as rt_e: if "network issue" in str(rt_e): raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e if "gdown dependency" in str(rt_e): # no gdown installed raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + if "md5 check" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e raise rt_e +def download_url_or_skip_test(*args, **kwargs): + """``download_url`` and skip the tests if any downloading error occurs.""" + with skip_if_downloading_fail(): + download_url(*args, **kwargs) + + def query_memory(n=2): """ Find best n idle devices and return a string of device ids using the `nvidia-smi` command. From 803bf06828a1218d67eb838238e071483b6ec5ab Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 27 Jan 2022 23:41:17 +0000 Subject: [PATCH 2/2] skip when downloading fails Signed-off-by: Wenqi Li --- tests/test_cross_validation.py | 11 +---- tests/test_decathlondataset.py | 11 +---- tests/test_download_and_extract.py | 6 +-- tests/test_efficientnet.py | 21 ++-------- tests/test_integration_classification_2d.py | 11 +---- tests/test_lr_finder.py | 4 +- tests/test_mednistdataset.py | 11 +---- tests/test_mmar_download.py | 17 ++------ tests/utils.py | 46 +++++++++------------ 9 files changed, 39 insertions(+), 99 deletions(-) diff --git a/tests/test_cross_validation.py b/tests/test_cross_validation.py index 72c7b53506..c378a52f78 100644 --- a/tests/test_cross_validation.py +++ b/tests/test_cross_validation.py @@ -11,12 +11,11 @@ import os import unittest -from urllib.error import ContentTooShortError, HTTPError from monai.apps import CrossValidation, DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils.enums import PostFix -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestCrossValidation(unittest.TestCase): @@ -51,14 +50,8 @@ def _test_dataset(dataset): download=True, ) - try: # will start downloading if testing_dir doesn't have the Decathlon files + with skip_if_downloading_fails(): data = cvdataset.get_dataset(folds=0) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index a5d9ce3e27..744dccefaa 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -13,12 +13,11 @@ import shutil import unittest from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError from monai.apps import DecathlonDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils.enums import PostFix -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestDecathlonDataset(unittest.TestCase): @@ -41,7 +40,7 @@ def _test_dataset(dataset): self.assertTrue(PostFix.meta("image") in dataset[0]) self.assertTupleEqual(dataset[0]["image"].shape, (1, 36, 47, 44)) - try: # will start downloading if testing_dir doesn't have the Decathlon files + with skip_if_downloading_fails(): data = DecathlonDataset( root_dir=testing_dir, task="Task04_Hippocampus", @@ -50,12 +49,6 @@ def _test_dataset(dataset): download=True, copy_cache=False, ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) data = DecathlonDataset( diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index 2a63a6b44e..435b280022 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -16,7 +16,7 @@ from urllib.error import ContentTooShortError, HTTPError from monai.apps import download_and_extract, download_url, extractall -from tests.utils import skip_if_downloading_fail, skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick class TestDownloadAndExtract(unittest.TestCase): @@ -27,7 +27,7 @@ def test_actions(self): filepath = Path(testing_dir) / "MedNIST.tar.gz" output_dir = Path(testing_dir) md5_value = "0bc7306e7427e00ad1c5526a6677552d" - with skip_if_downloading_fail(): + with skip_if_downloading_fails(): download_and_extract(url, filepath, output_dir, md5_value) download_and_extract(url, filepath, output_dir, md5_value) @@ -49,7 +49,7 @@ def test_actions(self): @skip_if_quick def test_default(self): with tempfile.TemporaryDirectory() as tmp_dir: - with skip_if_downloading_fail(): + with skip_if_downloading_fails(): # icon.tar.gz https://drive.google.com/file/d/1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn/view?usp=sharing download_and_extract( "https://drive.google.com/uc?id=1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn", diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index 0ab383fd56..a2a5e30750 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -13,7 +13,6 @@ import unittest from typing import TYPE_CHECKING from unittest import skipUnless -from urllib.error import ContentTooShortError, HTTPError import torch from parameterized import parameterized @@ -27,7 +26,7 @@ get_efficientnet_image_size, ) from monai.utils import optional_import -from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save +from tests.utils import skip_if_downloading_fails, skip_if_quick, test_pretrained_networks, test_script_save if TYPE_CHECKING: import torchvision @@ -251,12 +250,8 @@ class TestEFFICIENTNET(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - try: - # initialize model + with skip_if_downloading_fails(): net = EfficientNetBN(**input_param).to(device) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - return # skipping the tests because of http errors # run inference with random tensor with eval_mode(net): @@ -269,12 +264,8 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_non_default_shapes(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" - try: - # initialize model + with skip_if_downloading_fails(): net = EfficientNetBN(**input_param).to(device) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - return # skipping the tests because of http errors # override input shape with different variations num_dims = len(input_shape) - 2 @@ -387,12 +378,8 @@ class TestExtractFeatures(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shapes): device = "cuda" if torch.cuda.is_available() else "cpu" - try: - # initialize model + with skip_if_downloading_fails(): net = EfficientNetBNFeatures(**input_param).to(device) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - return # skipping the tests because of http errors # run inference with random tensor with eval_mode(net): diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 2c0c9e1f2e..3572678b64 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -12,7 +12,6 @@ import os import unittest import warnings -from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -39,7 +38,7 @@ ) from monai.utils import set_determinism from tests.testing_data.integration_answers import test_integration_value -from tests.utils import DistTestCase, TimedCall, skip_if_quick +from tests.utils import DistTestCase, TimedCall, skip_if_downloading_fails, skip_if_quick TEST_DATA_URL = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE" MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d" @@ -186,14 +185,8 @@ def setUp(self): dataset_file = os.path.join(self.data_dir, "MedNIST.tar.gz") if not os.path.exists(data_dir): - try: + with skip_if_downloading_fails(): download_and_extract(TEST_DATA_URL, dataset_file, self.data_dir, MD5_VALUE) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors assert os.path.exists(data_dir) diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index b79f1d697f..a76808be20 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -24,7 +24,7 @@ from monai.optimizers import LearningRateFinder from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils import optional_import, set_determinism -from tests.utils import skip_if_downloading_fail +from tests.utils import skip_if_downloading_fails if TYPE_CHECKING: import matplotlib.pyplot as plt @@ -62,7 +62,7 @@ def setUp(self): def test_lr_finder(self): # 0.001 gives 54 examples - with skip_if_downloading_fail(): + with skip_if_downloading_fails(): train_ds = MedNISTDataset( root_dir=self.root_dir, transform=self.transforms, diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 1da3b73de2..e7cc1a60ff 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -13,12 +13,11 @@ import shutil import unittest from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError from monai.apps import MedNISTDataset from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord from monai.utils.enums import PostFix -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick MEDNIST_FULL_DATASET_LENGTH = 58954 @@ -43,16 +42,10 @@ def _test_dataset(dataset): self.assertTrue(PostFix.meta("image") in dataset[0]) self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) - try: # will start downloading if testing_dir doesn't have the MedNIST files + with skip_if_downloading_fails(): data = MedNISTDataset( root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False ) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors _test_dataset(data) diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 2cae5969db..cab051e781 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -13,7 +13,6 @@ import tempfile import unittest from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -22,7 +21,7 @@ from monai.apps import RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar from monai.apps.mmars import MODEL_DESC from monai.apps.mmars.mmars import _get_val -from tests.utils import skip_if_quick +from tests.utils import skip_if_downloading_fails, skip_if_quick TEST_CASES = [["clara_pt_prostate_mri_segmentation_1"], ["clara_pt_covid19_ct_lesion_segmentation_1"]] TEST_EXTRACT_CASES = [ @@ -105,7 +104,7 @@ class TestMMMARDownload(unittest.TestCase): @parameterized.expand(TEST_CASES) @skip_if_quick def test_download(self, idx): - try: + with skip_if_downloading_fails(): # test model specification cand = get_model_spec(idx) self.assertEqual(cand[RemoteMMARKeys.ID], idx) @@ -116,22 +115,12 @@ def test_download(self, idx): download_mmar(idx, mmar_dir=tmp_dir, progress=False) download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1) # repeated to check caching self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx))) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, HTTPError): - self.assertTrue("500" in str(e)) # http error has the code 500 - return # skipping this test due the network connection errors @parameterized.expand(TEST_EXTRACT_CASES) @skip_if_quick def test_load_ckpt(self, input_args, expected_name, expected_val): - try: + with skip_if_downloading_fails(): output = load_from_mmar(**input_args) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - print(str(e)) - if isinstance(e, HTTPError): - self.assertTrue("500" in str(e)) # http error has the code 500 - return self.assertEqual(output.__class__.__name__, expected_name) x = next(output.parameters()) # verify the first element np.testing.assert_allclose(x[0][0].detach().cpu().numpy(), expected_val, rtol=1e-3, atol=1e-3) diff --git a/tests/utils.py b/tests/utils.py index 19510cd908..0af019f0b0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -25,7 +25,7 @@ from functools import partial from subprocess import PIPE, Popen from typing import Callable, Optional, Tuple -from urllib.error import ContentTooShortError, HTTPError, URLError +from urllib.error import ContentTooShortError, HTTPError import numpy as np import torch @@ -94,17 +94,25 @@ def assert_allclose( np.testing.assert_allclose(actual, desired, *args, **kwargs) -def test_pretrained_networks(network, input_param, device): +@contextmanager +def skip_if_downloading_fails(): try: + yield + except (ContentTooShortError, HTTPError, ConnectionError) as e: + raise unittest.SkipTest(f"error while downloading: {e}") from e + except RuntimeError as rt_e: + if "network issue" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + if "gdown dependency" in str(rt_e): # no gdown installed + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + if "md5 check" in str(rt_e): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e + raise rt_e + + +def test_pretrained_networks(network, input_param, device): + with skip_if_downloading_fails(): return network(**input_param).to(device) - except (URLError, HTTPError) as e: - raise unittest.SkipTest(e) from e - except RuntimeError as r_error: - if "unexpected EOF" in f"{r_error}": # The file might be corrupted. - raise unittest.SkipTest(f"{r_error}") from r_error - if "network issue" in f"{r_error}": # The network is not available. - raise unittest.SkipTest(f"{r_error}") from r_error - raise def test_is_quick(): @@ -650,25 +658,9 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): ) -@contextmanager -def skip_if_downloading_fail(): - try: - yield - except (ContentTooShortError, HTTPError, ConnectionError) as e: - raise unittest.SkipTest(f"error while downloading: {e}") from e - except RuntimeError as rt_e: - if "network issue" in str(rt_e): - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e - if "gdown dependency" in str(rt_e): # no gdown installed - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e - if "md5 check" in str(rt_e): - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e - raise rt_e - - def download_url_or_skip_test(*args, **kwargs): """``download_url`` and skip the tests if any downloading error occurs.""" - with skip_if_downloading_fail(): + with skip_if_downloading_fails(): download_url(*args, **kwargs)