diff --git a/test/test_utils/test_utils_nn.py b/test/test_utils/test_utils_nn.py index ec533707..7207ca26 100644 --- a/test/test_utils/test_utils_nn.py +++ b/test/test_utils/test_utils_nn.py @@ -2,6 +2,7 @@ """ import itertools +import shutil from pathlib import Path import numpy as np @@ -523,11 +524,31 @@ def test_mixin_classes(): assert save_path.is_file() loaded_model, _ = Model1D.from_checkpoint(save_path) - assert repr(model_1d) == repr(loaded_model) save_path.unlink() + # test remote un-compressed model + save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_remote_model" + save_path.mkdir(exist_ok=True, parents=True) + loaded_model, _ = Model1D.from_remote( + url="https://www.dropbox.com/scl/fi/5q5q0z0ta48ml0u2xtwm7/test-remote-model.pth?rlkey=2l2erhdnrfc4om0fqarokikb0&dl=1", + model_dir=save_path, + ) + assert isinstance(loaded_model, Model1D) + shutil.rmtree(save_path) + + # test remote compressed model + save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_remote_model" + save_path.mkdir(exist_ok=True, parents=True) + loaded_model, _ = Model1D.from_remote( + url="https://www.dropbox.com/scl/fi/2eqhnagz1m0w0ka86uegr/test-remote-model.zip?rlkey=1mkuwhx4kykqmc7h4rnou46z0&dl=1", + model_dir=save_path, + compressed=True, + ) + assert isinstance(loaded_model, Model1D) + shutil.rmtree(save_path) + model_dummy = ModelDummy(12, CFG(out_channels=128)) assert model_dummy.module_size == model_dummy.sizeof == 0 assert model_dummy.module_size_ == model_dummy.sizeof_ == "0.0B" diff --git a/torch_ecg/utils/utils_nn.py b/torch_ecg/utils/utils_nn.py index 00ba397d..73cbb978 100644 --- a/torch_ecg/utils/utils_nn.py +++ b/torch_ecg/utils/utils_nn.py @@ -18,6 +18,7 @@ from torch import Tensor, nn from ..cfg import CFG, DEFAULTS +from .download import http_get from .misc import add_docstring, make_serializable from .utils_data import cls_to_bin @@ -1043,7 +1044,9 @@ def from_checkpoint( Parameters ---------- path : `path-like` - Path of the checkpoint. + Path to the checkpoint. + If it is a directory, then this directory should contain only one checkpoint file + (with the extension `.pth` or `.pt`). device : torch.device, optional Map location of the model parameters, defaults to "cuda" if available, otherwise "cpu". @@ -1056,6 +1059,10 @@ def from_checkpoint( Auxiliary configs that are needed for data preprocessing, etc. """ + if Path(path).is_dir(): + candidates = list(Path(path).glob("*.pth")) + list(Path(path).glob("*.pt")) + assert len(candidates) == 1, "The directory should contain only one checkpoint file" + path = candidates[0] _device = device or DEFAULTS.device ckpt = torch.load(path, map_location=_device) aux_config = ckpt.get("train_config", None) or ckpt.get("config", None) @@ -1071,6 +1078,39 @@ def from_checkpoint( model.load_state_dict(ckpt["model_state_dict"]) return model, aux_config + @classmethod + def from_remote( + cls, + url: str, + model_dir: Union[str, bytes, os.PathLike], + filename: Optional[str] = None, + device: Optional[torch.device] = None, + ) -> Tuple[nn.Module, dict]: + """Load the model from the remote model. + + Parameters + ---------- + url : str + URL of the remote model. + model_dir : `path-like` + Path for downloading the model. + filename : str, optional + Filename of the model to save, defaults to the basename of the URL. + device : torch.device, optional + Map location of the model parameters, + defaults to "cuda" if available, otherwise "cpu". + + Returns + ------- + model : torch.nn.Module + The model loaded from a checkpoint. + aux_config : dict + Auxiliary configs that are needed for data preprocessing, etc. + + """ + http_get(url, model_dir, extract=True, filename=filename) + return cls.from_checkpoint(model_dir, device=device) + def save(self, path: Union[str, bytes, os.PathLike], train_config: CFG) -> None: """Save the model to disk.