From b8d520b8892404896ebe9ba19808dd7f55fd7cb1 Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Mon, 18 Sep 2023 23:47:15 +0000 Subject: [PATCH] fix(finetune): split training into three parts 1) training from scratch 2) resuming from a checkpoint without changes (preserves epoch and current step) and 3) fine-tuning by changing values in the training configuration --- everyvoice/base_cli/helpers.py | 31 +++++++++++++++++++++++++++++-- requirements.txt | 1 + 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/everyvoice/base_cli/helpers.py b/everyvoice/base_cli/helpers.py index b4d6502f..18629a7b 100644 --- a/everyvoice/base_cli/helpers.py +++ b/everyvoice/base_cli/helpers.py @@ -7,8 +7,10 @@ import os from enum import Enum from pathlib import Path +from pprint import pformat from typing import List, Optional, Union +from deepdiff import DeepDiff from loguru import logger from tqdm import tqdm @@ -147,8 +149,33 @@ def train_base_command( and os.path.exists(config.training.finetune_checkpoint) else None ) - tensorboard_logger.log_hyperparams(config.dict()) - trainer.fit(model_obj, data, ckpt_path=last_ckpt) + # Train from Scratch + if last_ckpt is None: + model_obj = model(config) + tensorboard_logger.log_hyperparams(config.dict()) + trainer.fit(model_obj, data) + else: + model_obj = model.load_from_checkpoint(last_ckpt) + # Check if the trainer has changed (but ignore subdir since it is specific to the run) + diff = DeepDiff(model_obj.config.training.dict(), config.training.dict()) + training_config_diff = [ + item for item in diff["values_changed"].items() if "sub_dir" not in item[0] + ] + if training_config_diff: + model_obj.config.training = config.training + tensorboard_logger.log_hyperparams(config.dict()) + # Finetune from Checkpoint + logger.warning( + f"""Some of your training hyperparameters have changed from your checkpoint at '{last_ckpt}', so we will override your checkpoint hyperparameters.
+ Your training logs will start from epoch 0/step 0, but will still use the weights from your checkpoint. Values Changed: {pformat(training_config_diff)} + """ + ) + trainer.fit(model_obj, data) + else: + logger.info(f"Resuming from checkpoint '{last_ckpt}'") + # Resume from checkpoint + tensorboard_logger.log_hyperparams(config.dict()) + trainer.fit(model_obj, data, ckpt_path=last_ckpt) def inference_base_command(name: Enum): diff --git a/requirements.txt b/requirements.txt index d6324b76..a1d61b22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ clipdetect>=0.1.3 +deepdiff>=6.5.0 anytree>=2.8.0 einops==0.5.0 g2p>=1.0.20230417