diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py index f54c30b6d..168a78ad8 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py @@ -17,6 +17,7 @@ import logging import os from abc import ABC +from pathlib import Path import torch from torch import distributed @@ -44,6 +45,9 @@ def __init__(self, params=None, json_path=None): self.params = get_from_dicts(params, default_parameters) self.params = get_from_json(json_path, self.params) self._sanity_check() + Path(self.params['output'] + ['save_model_dir']).expanduser().resolve().mkdir(parents=True, + exist_ok=True) logging.info("Model parameters : %s", self.params) self.input_type = self.params['input']['type']