Skip to content

Commit

Permalink
fix: error in save();
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Nov 29, 2023
1 parent b0e5474 commit 9acb086
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ def save(
)
else:
logger.error(f"File {saving_path} exists. Saving operation aborted.")

try:
create_dir_if_not_exist(saving_path)
create_dir_if_not_exist(saving_dir)
if isinstance(self.device, list):
# to save a DataParallel model generically, save the model.module.state_dict()
torch.save(self.model.module, saving_path)
Expand Down

0 comments on commit 9acb086

Please sign in to comment.