Skip to content

Commit

Permalink
Merge 9acb086 into 10408f2
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Nov 29, 2023
2 parents 10408f2 + 9acb086 commit fa27a9b
Show file tree
Hide file tree
Showing 13 changed files with 117 additions and 91 deletions.
112 changes: 81 additions & 31 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from typing import Optional, Union
from typing import Optional, Union, Iterable

import torch
from torch.utils.tensorboard import SummaryWriter
Expand Down Expand Up @@ -69,8 +69,8 @@ def __init__(
model_saving_strategy in saving_strategies
), f"saving_strategy must be one of {saving_strategies}, but got f{model_saving_strategy}."

self.device = None
self.saving_path = saving_path
self.device = None # set up with _setup_device() below
self.saving_path = None # set up with _setup_path() below
self.model_saving_strategy = model_saving_strategy

self.model = None
Expand All @@ -82,7 +82,7 @@ def __init__(
# set up saving_path to save the trained model and training logs
self._setup_path(saving_path)

def _setup_device(self, device: Union[None, str, torch.device, list]):
def _setup_device(self, device: Union[None, str, torch.device, list]) -> None:
if device is None:
# if it is None, then use the first cuda device if cuda is available, otherwise use cpu
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
Expand All @@ -105,7 +105,6 @@ def _setup_device(self, device: Union[None, str, torch.device, list]):
# train in parallel on multiple CUDA devices

# ensure the list is not empty

device_list = []
for idx, d in enumerate(device):
if isinstance(d, str):
Expand Down Expand Up @@ -141,7 +140,7 @@ def _setup_device(self, device: Union[None, str, torch.device, list]):
torch.cuda.is_available() and torch.cuda.device_count() > 0
), "You are trying to use CUDA for model training, but CUDA is not available in your environment."

def _setup_path(self, saving_path):
def _setup_path(self, saving_path) -> None:
if isinstance(saving_path, str):
# get the current time to append to saving_path,
# so you can use the same saving_path to run multiple times
Expand All @@ -164,7 +163,7 @@ def _setup_path(self, saving_path):
"saving_path not given. Model files and tensorboard file will not be saved."
)

def _send_model_to_given_device(self):
def _send_model_to_given_device(self) -> None:
if isinstance(self.device, list):
# train in parallel on multiple devices
self.model = torch.nn.DataParallel(self.model, device_ids=self.device)
Expand All @@ -175,7 +174,7 @@ def _send_model_to_given_device(self):
else:
self.model = self.model.to(self.device)

def _send_data_to_given_device(self, data):
def _send_data_to_given_device(self, data) -> Iterable:
if isinstance(self.device, torch.device): # single device
data = map(lambda x: x.to(self.device), data)
else: # parallely training on multiple devices
Expand Down Expand Up @@ -214,7 +213,7 @@ def _auto_save_model_if_necessary(
self,
training_finished: bool = True,
saving_name: str = None,
):
) -> None:
"""Automatically save the current model into a file if in need.
Parameters
Expand All @@ -230,17 +229,17 @@ def _auto_save_model_if_necessary(
"""
if self.saving_path is not None and self.model_saving_strategy is not None:
name = self.__class__.__name__ if saving_name is None else saving_name
saving_path = os.path.join(self.saving_path, name)
if not training_finished and self.model_saving_strategy == "better":
self.save_model(self.saving_path, name)
self.save(saving_path)
elif training_finished and self.model_saving_strategy == "best":
self.save_model(self.saving_path, name)
else:
return
self.save(saving_path)
else:
pass

def save_model(
def save(
self,
saving_dir: str,
file_name: str,
saving_path: str,
overwrite: bool = False,
) -> None:
"""Save the model with current parameters to a disk file.
Expand All @@ -251,19 +250,19 @@ def save_model(
Parameters
----------
saving_dir :
The given directory to save the model.
file_name :
The file name of the model to be saved.
saving_path :
The given path to save the model. The directory will be created if it does not exist.
overwrite :
Whether to overwrite the model file if the path already exists.
"""
file_name = (
file_name + ".pypots" if file_name.split(".")[-1] != "pypots" else file_name
)
# split the saving dir and file name from the given path
saving_dir, file_name = os.path.split(saving_path)
# add the suffix ".pypots" if not given
if file_name.split(".")[-1] != "pypots":
file_name += ".pypots"
# rejoin the path for saving the model
saving_path = os.path.join(saving_dir, file_name)

if os.path.exists(saving_path):
Expand All @@ -273,6 +272,7 @@ def save_model(
)
else:
logger.error(f"File {saving_path} exists. Saving operation aborted.")

try:
create_dir_if_not_exist(saving_dir)
if isinstance(self.device, list):
Expand All @@ -286,27 +286,27 @@ def save_model(
f'Failed to save the model to "{saving_path}" because of the below error! \n{e}'
)

def load_model(self, model_path: str) -> None:
def load(self, path: str) -> None:
"""Load the saved model from a disk file.
Parameters
----------
model_path :
Local path to a disk file saving trained model.
path :
The local path to a disk file saving the trained model.
Notes
-----
If the training environment and the deploying/test environment use the same type of device (GPU/CPU),
you can load the model directly with torch.load(path).
"""
assert os.path.exists(model_path), f"Model file {model_path} does not exist."
assert os.path.exists(path), f"Model file {path} does not exist."

try:
if isinstance(self.device, torch.device):
loaded_model = torch.load(model_path, map_location=self.device)
loaded_model = torch.load(path, map_location=self.device)
else:
loaded_model = torch.load(model_path)
loaded_model = torch.load(path)
if isinstance(loaded_model, torch.nn.Module):
if isinstance(self.device, torch.device):
self.model.load_state_dict(loaded_model.state_dict())
Expand All @@ -316,7 +316,57 @@ def load_model(self, model_path: str) -> None:
self.model = loaded_model.model
except Exception as e:
raise e
logger.info(f"Model loaded successfully from {model_path}.")
logger.info(f"Model loaded successfully from {path}.")

def save_model(self, saving_path: str, overwrite: bool = False) -> None:
    """Deprecated alias of :meth:`save`; writes the current model to a disk file.

    This method only logs a deprecation warning and then forwards both
    arguments unchanged to ``save()``, which performs the actual saving.

    Parameters
    ----------
    saving_path :
        The path to save the model to, passed through to ``save()``.

    overwrite :
        Whether to overwrite the model file if the path already exists,
        passed through to ``save()``.

    Warnings
    --------
    The method save_model is deprecated. Please use `save()` instead.

    """
    # Emit the notice through the project logger before delegating.
    deprecation_msg = "🚨DeprecationWarning: The method save_model is deprecated. Please use `save()` instead."
    logger.warning(deprecation_msg)
    self.save(saving_path, overwrite)

def load_model(self, path: str) -> None:
    """Deprecated alias of :meth:`load`; restores a saved model from a disk file.

    This method only logs a deprecation warning and then forwards ``path``
    unchanged to ``load()``, which performs the actual loading.

    Parameters
    ----------
    path :
        The local path to a disk file saving the trained model.

    Notes
    -----
    If the training environment and the deploying/test environment use the same type of device (GPU/CPU),
    you can load the model directly with torch.load(path).

    Warnings
    --------
    The method load_model is deprecated. Please use `load()` instead.

    """
    # Emit the notice through the project logger before delegating.
    deprecation_msg = "🚨DeprecationWarning: The method load_model is deprecated. Please use `load()` instead."
    logger.warning(deprecation_msg)
    self.load(path)

@abstractmethod
def fit(
Expand Down
8 changes: 3 additions & 5 deletions tests/classification/brits.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.brits)

# save the trained model into file, and check if the path exists
self.brits.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.brits.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.brits.load_model(saved_model_path)
self.brits.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/classification/grud.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.grud)

# save the trained model into file, and check if the path exists
self.grud.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.grud.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.grud.load_model(saved_model_path)
self.grud.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/classification/raindrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.raindrop)

# save the trained model into file, and check if the path exists
self.raindrop.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.raindrop.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.raindrop.load_model(saved_model_path)
self.raindrop.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/clustering/crli.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.crli_gru)

# save the trained model into file, and check if the path exists
self.crli_gru.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.crli_gru.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.crli_gru.load_model(saved_model_path)
self.crli_gru.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/clustering/vader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.vader)

# save the trained model into file, and check if the path exists
self.vader.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.vader.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.vader.load_model(saved_model_path)
self.vader.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/brits.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.brits)

# save the trained model into file, and check if the path exists
self.brits.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.brits.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.brits.load_model(saved_model_path)
self.brits.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/csdi.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.csdi)

# save the trained model into file, and check if the path exists
self.csdi.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.csdi.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.csdi.load_model(saved_model_path)
self.csdi.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/gpvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.gp_vae)

# save the trained model into file, and check if the path exists
self.gp_vae.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.gp_vae.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.gp_vae.load_model(saved_model_path)
self.gp_vae.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/mrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.mrnn)

# save the trained model into file, and check if the path exists
self.mrnn.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.mrnn.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.mrnn.load_model(saved_model_path)
self.mrnn.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/saits.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.saits)

# save the trained model into file, and check if the path exists
self.saits.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.saits.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.saits.load_model(saved_model_path)
self.saits.load(saved_model_path)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit fa27a9b

Please sign in to comment.