From 9acb086feb15b3b157896b32809a53d55eb9ce9a Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 29 Nov 2023 14:27:37 +0800 Subject: [PATCH] fix: error in save(); --- pypots/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pypots/base.py b/pypots/base.py index 2bd30a9f..3c77899a 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -272,8 +272,9 @@ def save( ) else: logger.error(f"File {saving_path} exists. Saving operation aborted.") + try: - create_dir_if_not_exist(saving_path) + create_dir_if_not_exist(saving_dir) if isinstance(self.device, list): # to save a DataParallel model generically, save the model.module.state_dict() torch.save(self.model.module, saving_path)