diff --git a/pypots/base.py b/pypots/base.py index 2bd30a9f..3c77899a 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -272,8 +272,9 @@ def save( ) else: logger.error(f"File {saving_path} exists. Saving operation aborted.") + try: - create_dir_if_not_exist(saving_path) + create_dir_if_not_exist(saving_dir) if isinstance(self.device, list): # to save a DataParallel model generically, save the model.module.state_dict() torch.save(self.model.module, saving_path)