Skip to content

Commit

Permalink
add download tests for Caltech(101|256) (pytorch#2731)
Browse files Browse the repository at this point in the history
* add download tests for Caltech(101|256)

* lint
  • Loading branch information
pmeier authored and vfdev-5 committed Dec 4, 2020
1 parent 37db3e7 commit 77aff2a
Showing 1 changed file with 52 additions and 13 deletions.
65 changes: 52 additions & 13 deletions test/test_datasets_download.py
Expand Up @@ -45,13 +45,23 @@ def inner_wrapper(request, *args, **kwargs):


@contextlib.contextmanager
def log_download_attempts(patch=True):
urls_and_md5s = set()
with unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url) as mock:
def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None):
if urls_and_md5s is None:
urls_and_md5s = set()
if patch_auxiliaries is None:
patch_auxiliaries = patch

with contextlib.ExitStack() as stack:
download_url_mock = stack.enter_context(
unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url)
)
if patch_auxiliaries:
# download_and_extract_archive
stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
try:
yield urls_and_md5s
finally:
for args, kwargs in mock.call_args_list:
for args, kwargs in download_url_mock.call_args_list:
url = args[0]
md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
urls_and_md5s.add((url, md5))
Expand Down Expand Up @@ -105,15 +115,14 @@ def __init__(self, url, md5=None, id=None):
self.md5 = md5
self.id = id or url

def __repr__(self):
return self.id

def make_parametrize_kwargs(download_configs):
argvalues = []
ids = []
for config in download_configs:
argvalues.append((config.url, config.md5))
ids.append(config.id)

return dict(argnames="url, md5", argvalues=argvalues, ids=ids)
def make_download_configs(urls_and_md5s, name=None):
return [
DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
]


def places365():
Expand All @@ -124,10 +133,40 @@ def places365():

datasets.Places365(root, split=split, small=small, download=True)

return [DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in urls_and_md5s]
return make_download_configs(urls_and_md5s, "Places365")


def caltech101():
try:
with log_download_attempts() as urls_and_md5s:
datasets.Caltech101(".", download=True)
except Exception:
pass

return make_download_configs(urls_and_md5s, "Caltech101")


def caltech256():
try:
with log_download_attempts() as urls_and_md5s:
datasets.Caltech256(".", download=True)
except Exception:
pass

return make_download_configs(urls_and_md5s, "Caltech256")


def make_parametrize_kwargs(download_configs):
argvalues = []
ids = []
for config in download_configs:
argvalues.append((config.url, config.md5))
ids.append(config.id)

return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(),)))
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256())))
def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url))

Expand Down

0 comments on commit 77aff2a

Please sign in to comment.