Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset
from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar
from .utils import check_hash, download_and_extract, download_url, extractall, get_logger, logger
from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger
17 changes: 8 additions & 9 deletions monai/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from urllib.error import ContentTooShortError, HTTPError, URLError
from urllib.request import urlretrieve

from monai.utils import min_version, optional_import
from monai.utils import look_up_option, min_version, optional_import

gdown, has_gdown = optional_import("gdown", "3.6")

Expand All @@ -33,9 +33,10 @@
else:
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

__all__ = ["check_hash", "download_url", "extractall", "download_and_extract", "get_logger"]
__all__ = ["check_hash", "download_url", "extractall", "download_and_extract", "get_logger", "SUPPORTED_HASH_TYPES"]

DEFAULT_FMT = "%(asctime)s - %(levelname)s - %(message)s"
SUPPORTED_HASH_TYPES = {"md5": hashlib.md5, "sha1": hashlib.sha1, "sha256": hashlib.sha256, "sha512": hashlib.sha512}


def get_logger(
Expand Down Expand Up @@ -117,18 +118,16 @@ def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5")
Args:
filepath: path of source file to verify hash value.
val: expected hash value of the file.
hash_type: 'md5' or 'sha1', defaults to 'md5'.
hash_type: type of hash algorithm to use, default is `"md5"`.
The supported hash types are `"md5"`, `"sha1"`, `"sha256"`, `"sha512"`.
See also: :py:data:`monai.apps.utils.SUPPORTED_HASH_TYPES`.

"""
if val is None:
logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.")
return True
if hash_type.lower() == "md5":
actual_hash = hashlib.md5()
elif hash_type.lower() == "sha1":
actual_hash = hashlib.sha1()
else:
raise NotImplementedError(f"Unknown 'hash_type' {hash_type}.")
actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES)
actual_hash = actual_hash_func()
try:
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
Expand Down
6 changes: 3 additions & 3 deletions tests/clang_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
# This dictionary maps each platform to a relative path to a file containing its reference hash.
# github/pytorch/pytorch/tree/63d62d3e44a0a4ec09d94f30381d49b78cc5b095/tools/clang_format_hash
PLATFORM_TO_HASH = {
"Darwin": "b24cc8972344c4e01afbbae78d6a414f7638ff6f",
"Linux": "9073602de1c4e1748f2feea5a0782417b20e3043",
"Darwin": "1485a242a96c737ba7cdd9f259114f2201accdb46d87ac7a8650b1a814cd4d4d",
"Linux": "e1c8b97b919541a99e0a355df5c3f9e8abebc64259dbee6f8c68e1ef90582856",
}

# Directory and file paths for the clang-format binary.
Expand All @@ -58,7 +58,7 @@ def get_and_check_clang_format():

try:
download_url(
PLATFORM_TO_CF_URL[HOST_PLATFORM], CLANG_FORMAT_PATH, PLATFORM_TO_HASH[HOST_PLATFORM], hash_type="sha1"
PLATFORM_TO_CF_URL[HOST_PLATFORM], CLANG_FORMAT_PATH, PLATFORM_TO_HASH[HOST_PLATFORM], hash_type="sha256"
)
except Exception as e:
print(f"Download {CLANG_FORMAT_PATH} failed: {e}")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_check_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_result(self, md5_value, t, expected_result):
self.assertTrue(result == expected_result)

def test_hash_type_error(self):
with self.assertRaises(NotImplementedError):
with self.assertRaises(ValueError):
with tempfile.TemporaryDirectory() as tempdir:
check_hash(tempdir, "test_hash", "test_type")

Expand Down