Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@

import os
import unittest
from urllib.error import ContentTooShortError, HTTPError

from monai.apps import CrossValidation, DecathlonDataset
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
from monai.utils.enums import PostFix
from tests.utils import skip_if_quick
from tests.utils import skip_if_downloading_fails, skip_if_quick


class TestCrossValidation(unittest.TestCase):
Expand Down Expand Up @@ -51,14 +50,8 @@ def _test_dataset(dataset):
download=True,
)

try: # will start downloading if testing_dir doesn't have the Decathlon files
with skip_if_downloading_fails():
data = cvdataset.get_dataset(folds=0)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors

_test_dataset(data)

Expand Down
11 changes: 2 additions & 9 deletions tests/test_decathlondataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
import shutil
import unittest
from pathlib import Path
from urllib.error import ContentTooShortError, HTTPError

from monai.apps import DecathlonDataset
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
from monai.utils.enums import PostFix
from tests.utils import skip_if_quick
from tests.utils import skip_if_downloading_fails, skip_if_quick


class TestDecathlonDataset(unittest.TestCase):
Expand All @@ -41,7 +40,7 @@ def _test_dataset(dataset):
self.assertTrue(PostFix.meta("image") in dataset[0])
self.assertTupleEqual(dataset[0]["image"].shape, (1, 36, 47, 44))

try: # will start downloading if testing_dir doesn't have the Decathlon files
with skip_if_downloading_fails():
data = DecathlonDataset(
root_dir=testing_dir,
task="Task04_Hippocampus",
Expand All @@ -50,12 +49,6 @@ def _test_dataset(dataset):
download=True,
copy_cache=False,
)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors

_test_dataset(data)
data = DecathlonDataset(
Expand Down
19 changes: 3 additions & 16 deletions tests/test_download_and_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from urllib.error import ContentTooShortError, HTTPError

from monai.apps import download_and_extract, download_url, extractall
from tests.utils import skip_if_quick
from tests.utils import skip_if_downloading_fails, skip_if_quick


class TestDownloadAndExtract(unittest.TestCase):
Expand All @@ -27,22 +27,15 @@ def test_actions(self):
filepath = Path(testing_dir) / "MedNIST.tar.gz"
output_dir = Path(testing_dir)
md5_value = "0bc7306e7427e00ad1c5526a6677552d"
try:
with skip_if_downloading_fails():
download_and_extract(url, filepath, output_dir, md5_value)
download_and_extract(url, filepath, output_dir, md5_value)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors

wrong_md5 = "0"
with self.assertLogs(logger="monai.apps", level="ERROR"):
try:
download_url(url, filepath, wrong_md5)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
Expand All @@ -56,7 +49,7 @@ def test_actions(self):
@skip_if_quick
def test_default(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
with skip_if_downloading_fails():
# icon.tar.gz https://drive.google.com/file/d/1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn/view?usp=sharing
download_and_extract(
"https://drive.google.com/uc?id=1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn",
Expand All @@ -71,12 +64,6 @@ def test_default(self):
hash_val="ac6e167ee40803577d98237f2b0241e5",
file_type="zip",
)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors


if __name__ == "__main__":
Expand Down
21 changes: 4 additions & 17 deletions tests/test_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import unittest
from typing import TYPE_CHECKING
from unittest import skipUnless
from urllib.error import ContentTooShortError, HTTPError

import torch
from parameterized import parameterized
Expand All @@ -27,7 +26,7 @@
get_efficientnet_image_size,
)
from monai.utils import optional_import
from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save
from tests.utils import skip_if_downloading_fails, skip_if_quick, test_pretrained_networks, test_script_save

if TYPE_CHECKING:
import torchvision
Expand Down Expand Up @@ -251,12 +250,8 @@ class TestEFFICIENTNET(unittest.TestCase):
def test_shape(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
# initialize model
with skip_if_downloading_fails():
net = EfficientNetBN(**input_param).to(device)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
return # skipping the tests because of http errors

# run inference with random tensor
with eval_mode(net):
Expand All @@ -269,12 +264,8 @@ def test_shape(self, input_param, input_shape, expected_shape):
def test_non_default_shapes(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
# initialize model
with skip_if_downloading_fails():
net = EfficientNetBN(**input_param).to(device)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
return # skipping the tests because of http errors

# override input shape with different variations
num_dims = len(input_shape) - 2
Expand Down Expand Up @@ -387,12 +378,8 @@ class TestExtractFeatures(unittest.TestCase):
def test_shape(self, input_param, input_shape, expected_shapes):
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
# initialize model
with skip_if_downloading_fails():
net = EfficientNetBNFeatures(**input_param).to(device)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
return # skipping the tests because of http errors

# run inference with random tensor
with eval_mode(net):
Expand Down
11 changes: 2 additions & 9 deletions tests/test_integration_classification_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import os
import unittest
import warnings
from urllib.error import ContentTooShortError, HTTPError

import numpy as np
import torch
Expand All @@ -39,7 +38,7 @@
)
from monai.utils import set_determinism
from tests.testing_data.integration_answers import test_integration_value
from tests.utils import DistTestCase, TimedCall, skip_if_quick
from tests.utils import DistTestCase, TimedCall, skip_if_downloading_fails, skip_if_quick

TEST_DATA_URL = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d"
Expand Down Expand Up @@ -186,14 +185,8 @@ def setUp(self):
dataset_file = os.path.join(self.data_dir, "MedNIST.tar.gz")

if not os.path.exists(data_dir):
try:
with skip_if_downloading_fails():
download_and_extract(TEST_DATA_URL, dataset_file, self.data_dir, MD5_VALUE)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors

assert os.path.exists(data_dir)

Expand Down
18 changes: 10 additions & 8 deletions tests/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from monai.optimizers import LearningRateFinder
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
from monai.utils import optional_import, set_determinism
from tests.utils import skip_if_downloading_fails

if TYPE_CHECKING:
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -61,14 +62,15 @@ def setUp(self):

def test_lr_finder(self):
# 0.001 gives 54 examples
train_ds = MedNISTDataset(
root_dir=self.root_dir,
transform=self.transforms,
section="validation",
val_frac=0.001,
download=True,
num_workers=10,
)
with skip_if_downloading_fails():
train_ds = MedNISTDataset(
root_dir=self.root_dir,
transform=self.transforms,
section="validation",
val_frac=0.001,
download=True,
num_workers=10,
)
train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)
num_classes = train_ds.get_num_classes()

Expand Down
11 changes: 2 additions & 9 deletions tests/test_mednistdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
import shutil
import unittest
from pathlib import Path
from urllib.error import ContentTooShortError, HTTPError

from monai.apps import MedNISTDataset
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
from monai.utils.enums import PostFix
from tests.utils import skip_if_quick
from tests.utils import skip_if_downloading_fails, skip_if_quick

MEDNIST_FULL_DATASET_LENGTH = 58954

Expand All @@ -43,16 +42,10 @@ def _test_dataset(dataset):
self.assertTrue(PostFix.meta("image") in dataset[0])
self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))

try: # will start downloading if testing_dir doesn't have the MedNIST files
with skip_if_downloading_fails():
data = MedNISTDataset(
root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False
)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors

_test_dataset(data)

Expand Down
17 changes: 3 additions & 14 deletions tests/test_mmar_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import tempfile
import unittest
from pathlib import Path
from urllib.error import ContentTooShortError, HTTPError

import numpy as np
import torch
Expand All @@ -22,7 +21,7 @@
from monai.apps import RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar
from monai.apps.mmars import MODEL_DESC
from monai.apps.mmars.mmars import _get_val
from tests.utils import skip_if_quick
from tests.utils import skip_if_downloading_fails, skip_if_quick

TEST_CASES = [["clara_pt_prostate_mri_segmentation_1"], ["clara_pt_covid19_ct_lesion_segmentation_1"]]
TEST_EXTRACT_CASES = [
Expand Down Expand Up @@ -105,7 +104,7 @@ class TestMMMARDownload(unittest.TestCase):
@parameterized.expand(TEST_CASES)
@skip_if_quick
def test_download(self, idx):
try:
with skip_if_downloading_fails():
# test model specification
cand = get_model_spec(idx)
self.assertEqual(cand[RemoteMMARKeys.ID], idx)
Expand All @@ -116,22 +115,12 @@ def test_download(self, idx):
download_mmar(idx, mmar_dir=tmp_dir, progress=False)
download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1) # repeated to check caching
self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx)))
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, HTTPError):
self.assertTrue("500" in str(e)) # http error has the code 500
return # skipping this test due the network connection errors

@parameterized.expand(TEST_EXTRACT_CASES)
@skip_if_quick
def test_load_ckpt(self, input_args, expected_name, expected_val):
try:
with skip_if_downloading_fails():
output = load_from_mmar(**input_args)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, HTTPError):
self.assertTrue("500" in str(e)) # http error has the code 500
return
self.assertEqual(output.__class__.__name__, expected_name)
x = next(output.parameters()) # verify the first element
np.testing.assert_allclose(x[0][0].detach().cpu().numpy(), expected_val, rtol=1e-3, atol=1e-3)
Expand Down
39 changes: 20 additions & 19 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
import traceback
import unittest
import warnings
from contextlib import contextmanager
from functools import partial
from subprocess import PIPE, Popen
from typing import Callable, Optional, Tuple
from urllib.error import ContentTooShortError, HTTPError, URLError
from urllib.error import ContentTooShortError, HTTPError

import numpy as np
import torch
Expand Down Expand Up @@ -93,17 +94,25 @@ def assert_allclose(
np.testing.assert_allclose(actual, desired, *args, **kwargs)


def test_pretrained_networks(network, input_param, device):
@contextmanager
def skip_if_downloading_fails():
try:
yield
except (ContentTooShortError, HTTPError, ConnectionError) as e:
raise unittest.SkipTest(f"error while downloading: {e}") from e
except RuntimeError as rt_e:
if "network issue" in str(rt_e):
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
if "gdown dependency" in str(rt_e): # no gdown installed
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
if "md5 check" in str(rt_e):
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
raise rt_e


def test_pretrained_networks(network, input_param, device):
with skip_if_downloading_fails():
return network(**input_param).to(device)
except (URLError, HTTPError) as e:
raise unittest.SkipTest(e) from e
except RuntimeError as r_error:
if "unexpected EOF" in f"{r_error}": # The file might be corrupted.
raise unittest.SkipTest(f"{r_error}") from r_error
if "network issue" in f"{r_error}": # The network is not available.
raise unittest.SkipTest(f"{r_error}") from r_error
raise


def test_is_quick():
Expand Down Expand Up @@ -651,16 +660,8 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0):

def download_url_or_skip_test(*args, **kwargs):
"""``download_url`` and skip the tests if any downloading error occurs."""
try:
with skip_if_downloading_fails():
download_url(*args, **kwargs)
except (ContentTooShortError, HTTPError) as e:
raise unittest.SkipTest(f"error while downloading: {e}") from e
except RuntimeError as rt_e:
if "network issue" in str(rt_e):
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
if "gdown dependency" in str(rt_e): # no gdown installed
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
raise rt_e


def query_memory(n=2):
Expand Down