Skip to content

Commit

Permalink
add class method from_remote for CkptMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
wenh06 committed Apr 1, 2024
1 parent ba11656 commit 4e8ace5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
23 changes: 22 additions & 1 deletion test/test_utils/test_utils_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""

import itertools
import shutil
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -523,11 +524,31 @@ def test_mixin_classes():
assert save_path.is_file()

loaded_model, _ = Model1D.from_checkpoint(save_path)

assert repr(model_1d) == repr(loaded_model)

save_path.unlink()

# test remote un-compressed model
save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_remote_model"
save_path.mkdir(exist_ok=True, parents=True)
loaded_model, _ = Model1D.from_remote(
url="https://www.dropbox.com/scl/fi/5q5q0z0ta48ml0u2xtwm7/test-remote-model.pth?rlkey=2l2erhdnrfc4om0fqarokikb0&dl=1",
model_dir=save_path,
)
assert isinstance(loaded_model, Model1D)
shutil.rmtree(save_path)

# test remote compressed model
save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_remote_model"
save_path.mkdir(exist_ok=True, parents=True)
loaded_model, _ = Model1D.from_remote(
url="https://www.dropbox.com/scl/fi/2eqhnagz1m0w0ka86uegr/test-remote-model.zip?rlkey=1mkuwhx4kykqmc7h4rnou46z0&dl=1",
model_dir=save_path,
compressed=True,
)
assert isinstance(loaded_model, Model1D)
shutil.rmtree(save_path)

model_dummy = ModelDummy(12, CFG(out_channels=128))
assert model_dummy.module_size == model_dummy.sizeof == 0
assert model_dummy.module_size_ == model_dummy.sizeof_ == "0.0B"
Expand Down
42 changes: 41 additions & 1 deletion torch_ecg/utils/utils_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch import Tensor, nn

from ..cfg import CFG, DEFAULTS
from .download import http_get
from .misc import add_docstring, make_serializable
from .utils_data import cls_to_bin

Expand Down Expand Up @@ -1043,7 +1044,9 @@ def from_checkpoint(
Parameters
----------
path : `path-like`
Path of the checkpoint.
Path to the checkpoint.
If it is a directory, then this directory should contain only one checkpoint file
(with the extension `.pth` or `.pt`).
device : torch.device, optional
Map location of the model parameters,
defaults to "cuda" if available, otherwise "cpu".
Expand All @@ -1056,6 +1059,10 @@ def from_checkpoint(
Auxiliary configs that are needed for data preprocessing, etc.
"""
if Path(path).is_dir():
candidates = list(Path(path).glob("*.pth")) + list(Path(path).glob("*.pt"))
assert len(candidates) == 1, "The directory should contain only one checkpoint file"
path = candidates[0]
_device = device or DEFAULTS.device
ckpt = torch.load(path, map_location=_device)
aux_config = ckpt.get("train_config", None) or ckpt.get("config", None)
Expand All @@ -1071,6 +1078,39 @@ def from_checkpoint(
model.load_state_dict(ckpt["model_state_dict"])
return model, aux_config

@classmethod
def from_remote(
cls,
url: str,
model_dir: Union[str, bytes, os.PathLike],
filename: Optional[str] = None,
device: Optional[torch.device] = None,
) -> Tuple[nn.Module, dict]:
"""Load the model from the remote model.
Parameters
----------
url : str
URL of the remote model.
model_dir : `path-like`
Path for downloading the model.
filename : str, optional
Filename of the model to save, defaults to the basename of the URL.
device : torch.device, optional
Map location of the model parameters,
defaults to "cuda" if available, otherwise "cpu".
Returns
-------
model : torch.nn.Module
The model loaded from a checkpoint.
aux_config : dict
Auxiliary configs that are needed for data preprocessing, etc.
"""
http_get(url, model_dir, extract=True, filename=filename)
return cls.from_checkpoint(model_dir, device=device)

def save(self, path: Union[str, bytes, os.PathLike], train_config: CFG) -> None:
"""Save the model to disk.
Expand Down

0 comments on commit 4e8ace5

Please sign in to comment.