diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index cb3ae14885..1df6d74f9d 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -11,4 +11,4 @@ from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar -from .utils import check_hash, download_and_extract, download_url, extractall, get_logger, logger +from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 2961f39720..ad66f33490 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -22,7 +22,7 @@ from urllib.error import ContentTooShortError, HTTPError, URLError from urllib.request import urlretrieve -from monai.utils import min_version, optional_import +from monai.utils import look_up_option, min_version, optional_import gdown, has_gdown = optional_import("gdown", "3.6") @@ -33,9 +33,10 @@ else: tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") -__all__ = ["check_hash", "download_url", "extractall", "download_and_extract", "get_logger"] +__all__ = ["check_hash", "download_url", "extractall", "download_and_extract", "get_logger", "SUPPORTED_HASH_TYPES"] DEFAULT_FMT = "%(asctime)s - %(levelname)s - %(message)s" +SUPPORTED_HASH_TYPES = {"md5": hashlib.md5, "sha1": hashlib.sha1, "sha256": hashlib.sha256, "sha512": hashlib.sha512} def get_logger( @@ -117,18 +118,16 @@ def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") Args: filepath: path of source file to verify hash value. val: expected hash value of the file. - hash_type: 'md5' or 'sha1', defaults to 'md5'. + hash_type: type of hash algorithm to use, default is `"md5"`. + The supported hash types are `"md5"`, `"sha1"`, `"sha256"`, `"sha512"`. + See also: :py:data:`monai.apps.utils.SUPPORTED_HASH_TYPES`. """ if val is None: logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.") return True - if hash_type.lower() == "md5": - actual_hash = hashlib.md5() - elif hash_type.lower() == "sha1": - actual_hash = hashlib.sha1() - else: - raise NotImplementedError(f"Unknown 'hash_type' {hash_type}.") + actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES) + actual_hash = actual_hash_func() try: with open(filepath, "rb") as f: for chunk in iter(lambda: f.read(1024 * 1024), b""): diff --git a/tests/clang_format_utils.py b/tests/clang_format_utils.py index 1391fdcd47..22f86c50b9 100644 --- a/tests/clang_format_utils.py +++ b/tests/clang_format_utils.py @@ -34,8 +34,8 @@ # This dictionary maps each platform to a relative path to a file containing its reference hash. # github/pytorch/pytorch/tree/63d62d3e44a0a4ec09d94f30381d49b78cc5b095/tools/clang_format_hash PLATFORM_TO_HASH = { - "Darwin": "b24cc8972344c4e01afbbae78d6a414f7638ff6f", - "Linux": "9073602de1c4e1748f2feea5a0782417b20e3043", + "Darwin": "1485a242a96c737ba7cdd9f259114f2201accdb46d87ac7a8650b1a814cd4d4d", + "Linux": "e1c8b97b919541a99e0a355df5c3f9e8abebc64259dbee6f8c68e1ef90582856", } # Directory and file paths for the clang-format binary. @@ -58,7 +58,7 @@ def get_and_check_clang_format(): try: download_url( - PLATFORM_TO_CF_URL[HOST_PLATFORM], CLANG_FORMAT_PATH, PLATFORM_TO_HASH[HOST_PLATFORM], hash_type="sha1" + PLATFORM_TO_CF_URL[HOST_PLATFORM], CLANG_FORMAT_PATH, PLATFORM_TO_HASH[HOST_PLATFORM], hash_type="sha256" ) except Exception as e: print(f"Download {CLANG_FORMAT_PATH} failed: {e}") diff --git a/tests/test_check_hash.py b/tests/test_check_hash.py index 0126b3c1a3..8009302ad0 100644 --- a/tests/test_check_hash.py +++ b/tests/test_check_hash.py @@ -41,7 +41,7 @@ def test_result(self, md5_value, t, expected_result): self.assertTrue(result == expected_result) def test_hash_type_error(self): - with self.assertRaises(NotImplementedError): + with self.assertRaises(ValueError): with tempfile.TemporaryDirectory() as tempdir: check_hash(tempdir, "test_hash", "test_type")