fix(trainer): load checkpoint without ckpt_path which doesn't allow for updated hyperparameters

fixes #41
roedoejet committed Sep 11, 2023
1 parent bd2ad72 commit 4d83b1c
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions everyvoice/base_cli/helpers.py
Codecov / codecov/patch warning: added line #L10 was not covered by tests.

@@ -7,6 +7,7 @@
 import os
 from enum import Enum
 from pathlib import Path
+from tempfile import TemporaryDirectory
 from typing import List, Optional, Union

 from loguru import logger

Codecov / codecov/patch warning: added lines #L179, #L181, #L183-L184, #L187, #L189 and #L191-L193 were not covered by tests.

@@ -167,16 +168,29 @@ def train_base_command(
         num_nodes=nodes,
         detect_anomaly=False,  # used for debugging, but triples training time
     )
-    model_obj = model(config)
+
     data = data_module(config)  # type: ignore
     last_ckpt = (
         config.training.finetune_checkpoint
         if config.training.finetune_checkpoint is not None
         and os.path.exists(config.training.finetune_checkpoint)
         else None
     )
-    tensorboard_logger.log_hyperparams(config.dict())
-    trainer.fit(model_obj, data, ckpt_path=last_ckpt)
+    with TemporaryDirectory() as tmpdir:
+        if last_ckpt is not None:
+            import torch
+
+            model_obj = torch.load(last_ckpt)
+            model_obj["hyper_parameters"]["config"].training = config.training  # type: ignore
+            # This is silly but seems to be the most straightforward way of
+            # not losing the current epoch https://github.com/Lightning-AI/lightning/issues/12819#issuecomment-1644018988
+            last_ckpt = Path(tmpdir) / "new.ckpt"
+            # save modified file
+            torch.save(model_obj, last_ckpt)
+            # delete model from memory
+            del model_obj
+        tensorboard_logger.log_hyperparams(config.dict())
+        trainer.fit(model(config), data, ckpt_path=last_ckpt)


def inference_base_command(name: Enum):
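
Note: the checkpoint-patching step in the hunk above can be sketched in isolation. This is a minimal illustration, not the project's code. It assumes the checkpoint is a plain dictionary and uses the standard-library pickle module as a stand-in for torch.load and torch.save so that it runs without PyTorch. The helper name patch_checkpoint, the toy keys "epoch" and "global_step", and the nested-dict config layout are all hypothetical; the diff itself assigns to an attribute, ckpt["hyper_parameters"]["config"].training.

```python
import pickle
from pathlib import Path
from tempfile import TemporaryDirectory


def patch_checkpoint(src: Path, dst: Path, new_training: dict) -> None:
    """Copy a checkpoint, replacing only the stored training settings."""
    with src.open("rb") as f:
        ckpt = pickle.load(f)  # stand-in for torch.load(src)
    # Hypothetical layout: a nested dict instead of a config object.
    ckpt["hyper_parameters"]["config"]["training"] = new_training
    with dst.open("wb") as f:
        pickle.dump(ckpt, f)  # stand-in for torch.save(ckpt, dst)


with TemporaryDirectory() as tmpdir:
    original = Path(tmpdir) / "last.ckpt"
    patched = Path(tmpdir) / "new.ckpt"

    # Toy checkpoint: progress counters plus the settings it was saved with.
    toy = {
        "epoch": 12,
        "global_step": 3400,
        "hyper_parameters": {"config": {"training": {"max_epochs": 20}}},
    }
    with original.open("wb") as f:
        pickle.dump(toy, f)

    patch_checkpoint(original, patched, {"max_epochs": 50})

    # Read the patched file back while the temporary directory still exists.
    with patched.open("rb") as f:
        result = pickle.load(f)

print(result["epoch"], result["hyper_parameters"]["config"]["training"])
# prints: 12 {'max_epochs': 50}
```

The progress counters survive the round trip and only the training settings change, which appears to be the point of resuming from a rewritten checkpoint file. The temporary directory is removed when the with block exits, so in the diff the trainer.fit call has to stay inside that block.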