Skip to content

Commit

Permalink
Add warning message for passing ckpt_path that points to a non-existe…
Browse files Browse the repository at this point in the history
…nt file
  • Loading branch information
amorehead committed Oct 8, 2023
1 parent fc265af commit 6709b36
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hydra
import lightning as L
import rootutils
import os
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.pytorch.loggers import Logger
Expand Down Expand Up @@ -127,7 +128,14 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

if cfg.get("train"):
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
ckpt_path = None
if cfg.get("ckpt_path") and os.path.exists(cfg.get("ckpt_path")):
ckpt_path = cfg.get("ckpt_path")
elif cfg.get("ckpt_path"):
log.warning(
"`ckpt_path` was given, but the path does not exist. Training with new model weights."
)
trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

train_metrics = trainer.callback_metrics

Expand Down

0 comments on commit 6709b36

Please sign in to comment.