Commit
fix(train): fix checkpoint not properly loaded (#271)
34j committed Apr 9, 2023
1 parent 0a03035 commit 0979147
Showing 1 changed file with 63 additions and 26 deletions.
89 changes: 63 additions & 26 deletions src/so_vits_svc_fork/train.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import math
import warnings
from logging import getLogger
from pathlib import Path
@@ -69,12 +68,9 @@ def train(
trainer = pl.Trainer(
logger=TensorBoardLogger(model_path),
# profiler="simple",
val_check_interval=hparams.train.eval_interval,
max_epochs=hparams.train.epochs,
check_val_every_n_epoch=math.ceil(
hparams.train.eval_interval
/ len(datamodule.train_dataset)
* hparams.train.batch_size
),
check_val_every_n_epoch=None,
precision=16
if hparams.train.fp16_run
else "bf16"
@@ -87,7 +83,11 @@ def train(

class VitsLightning(pl.LightningModule):
def on_train_start(self) -> None:
self.load(False)
self.set_current_epoch(self._temp_epoch)
total_batch_idx = self._temp_epoch * len(self.trainer.train_dataloader)
self.set_total_batch_idx(total_batch_idx)
global_step = total_batch_idx * self.optimizers_count
self.set_global_step(global_step)

# check if using tpu
if isinstance(self.trainer.accelerator, TPUAccelerator):
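
The three set_* calls above reconstruct Lightning's progress counters from the only value the checkpoint actually stores, the epoch. A standalone restatement of that arithmetic; the function name and the example numbers are illustrative:

def restored_counters(epoch: int, batches_per_epoch: int, optimizers_count: int = 2):
    """Return (total_batch_idx, global_step) when resuming at `epoch`."""
    # Every finished epoch contributed batches_per_epoch training batches.
    total_batch_idx = epoch * batches_per_epoch
    # Under manual optimization each batch steps both the generator and the
    # discriminator optimizer, so the global step advances optimizers_count
    # times per batch.
    global_step = total_batch_idx * optimizers_count
    return total_batch_idx, global_step

print(restored_counters(epoch=10, batches_per_epoch=125))  # (1250, 2500)
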
@@ -140,6 +140,22 @@ def set_global_step(self, global_step: int):
)
assert self.global_step == global_step, f"{self.global_step} != {global_step}"

def set_total_batch_idx(self, total_batch_idx: int):
LOG.info(f"Setting total batch idx to {total_batch_idx}")
self.trainer.fit_loop.epoch_loop.batch_progress.total.ready = (
total_batch_idx + 1
)
self.trainer.fit_loop.epoch_loop.batch_progress.total.completed = (
total_batch_idx
)
assert (
self.total_batch_idx == total_batch_idx + 1
), f"{self.total_batch_idx} != {total_batch_idx + 1}"

@property
def total_batch_idx(self) -> int:
return self.trainer.fit_loop.epoch_loop.total_batch_idx + 1
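
For readers unfamiliar with Lightning's loop internals: the batch progress object tracks how many batches have been made ready versus fully completed, and the epoch loop derives its running batch index from those counters (apparently as ready - 1, which is what the assertion above relies on). A rough stand-in, with hypothetical dataclasses in place of Lightning's progress objects:

from dataclasses import dataclass, field

@dataclass
class _Tracker:
    ready: int = 0      # batches fetched / started
    completed: int = 0  # batches fully processed

@dataclass
class _BatchProgress:
    total: _Tracker = field(default_factory=_Tracker)

def set_total_batch_idx(progress: _BatchProgress, total_batch_idx: int) -> None:
    # Mirrors the helper above: the next batch is "ready", all earlier ones
    # are "completed", so the index derived from the counters lands exactly
    # on total_batch_idx.
    progress.total.ready = total_batch_idx + 1
    progress.total.completed = total_batch_idx

progress = _BatchProgress()
set_total_batch_idx(progress, 1250)
loop_batch_idx = progress.total.ready - 1   # what the epoch loop would report
print(loop_batch_idx, loop_batch_idx + 1)   # 1250 1251 (the module's property)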

def load(self, reset_optimizer: bool = False):
latest_g_path = utils.latest_checkpoint_path(self.hparams.model_dir, "G_*.pth")
latest_d_path = utils.latest_checkpoint_path(self.hparams.model_dir, "D_*.pth")
@@ -157,12 +173,9 @@ def load(self, reset_optimizer: bool = False):
self.optim_d,
reset_optimizer,
)
self.set_current_epoch(epoch)
global_step = epoch * len(self.trainer.train_dataloader)
self.set_global_step(global_step)
assert self.current_epoch == epoch, f"{self.current_epoch} != {epoch}"
self.scheduler_g.last_epoch = self.current_epoch - 1
self.scheduler_d.last_epoch = self.current_epoch - 1
self._temp_epoch = epoch
self.scheduler_g.last_epoch = epoch - 1
self.scheduler_d.last_epoch = epoch - 1
except Exception as e:
raise RuntimeError("Failed to load checkpoint") from e
else:
@@ -198,6 +211,8 @@ def __init__(self, reset_optimizer: bool = False, **hparams: Any):
self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
self.optim_d, gamma=self.hparams.train.lr_decay
)
self.optimizers_count = 2
self.load(reset_optimizer)

def configure_optimizers(self):
return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]
@@ -211,7 +226,7 @@ def log_image_dict(
writer: SummaryWriter = self.logger.experiment
for k, v in image_dict.items():
try:
writer.add_image(k, v, self.global_step, dataformats=dataformats)
writer.add_image(k, v, self.total_batch_idx, dataformats=dataformats)
except Exception as e:
warnings.warn(f"Failed to log image {k}: {e}")

@@ -222,9 +237,25 @@ def log_audio_dict(self, audio_dict: dict[str, Any]) -> None:
writer: SummaryWriter = self.logger.experiment
for k, v in audio_dict.items():
writer.add_audio(
k, v, self.global_step, sample_rate=self.hparams.data.sampling_rate
k,
v,
self.trainer.fit_loop.total_batch_idx,
sample_rate=self.hparams.data.sampling_rate,
)

def log_dict_(self, log_dict: dict[str, Any], **kwargs) -> None:
if not isinstance(self.logger, TensorBoardLogger):
warnings.warn("Logging is only supported with TensorBoardLogger.")
return
writer: SummaryWriter = self.logger.experiment
for k, v in log_dict.items():
writer.add_scalar(k, v, self.total_batch_idx)
kwargs["logger"] = False
self.log_dict(log_dict, **kwargs)

def log_(self, key: str, value: Any, **kwargs) -> None:
self.log_dict_({key: value}, **kwargs)
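
The wrappers above write directly to the underlying SummaryWriter so the x-axis can be pinned to total_batch_idx, while logger=False keeps Lightning from writing the same values a second time at a different step. A tiny standalone illustration of the add_scalar step argument this relies on; the log directory, values, and step range are made up:

from torch.utils.tensorboard import SummaryWriter

# Because add_scalar takes an explicit global_step, a resumed run that keeps
# counting total_batch_idx from where it left off appends to the same curve
# instead of restarting the plot at step 0.
writer = SummaryWriter(log_dir="logs/demo")   # hypothetical log directory
for step in range(1250, 1260):                # e.g. resuming at total_batch_idx 1250
    writer.add_scalar("loss/g/total", 0.5, step)
writer.close()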

def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
self.net_g.train()
self.net_d.train()
@@ -282,9 +313,11 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
loss_gen_all += loss_subband

# log loss
self.log("grad_norm_g", commons.clip_grad_value_(self.net_g.parameters(), None))
self.log("lr", self.optim_g.param_groups[0]["lr"])
self.log_dict(
self.log_(
"grad_norm_g", commons.clip_grad_value_(self.net_g.parameters(), None)
)
self.log_("lr", self.optim_g.param_groups[0]["lr"])
self.log_dict_(
{
"loss/g/total": loss_gen_all,
"loss/g/fm": loss_fm,
@@ -295,8 +328,8 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
prog_bar=True,
)
if self.hparams.model.get("type_") == "mb-istft":
self.log("loss/g/subband", loss_subband)
if self.global_step % self.hparams.train.log_interval == 0:
self.log_("loss/g/subband", loss_subband)
if self.total_batch_idx % self.hparams.train.log_interval == 0:
self.log_image_dict(
{
"slice/mel_org": utils.plot_spectrogram_to_numpy(
@@ -338,8 +371,10 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
loss_disc_all = loss_disc

# log loss
self.log("loss/d/total", loss_disc_all, prog_bar=True)
self.log("grad_norm_d", commons.clip_grad_value_(self.net_d.parameters(), None))
self.log_("loss/d/total", loss_disc_all, prog_bar=True)
self.log_(
"grad_norm_d", commons.clip_grad_value_(self.net_d.parameters(), None)
)

# optimizer
self.manual_backward(loss_disc_all)
@@ -364,19 +399,21 @@ def validation_step(self, batch, batch_idx):
"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
}
)
if self.current_epoch == 0:
return
utils.save_checkpoint(
self.net_g,
self.optim_g,
self.hparams.train.learning_rate,
self.current_epoch,
Path(self.hparams.model_dir) / f"G_{self.global_step}.pth",
self.current_epoch + 1, # err toward over-counting rather than under-counting the resumed epoch
Path(self.hparams.model_dir) / f"G_{self.total_batch_idx}.pth",
)
utils.save_checkpoint(
self.net_d,
self.optim_d,
self.hparams.train.learning_rate,
self.current_epoch,
Path(self.hparams.model_dir) / f"D_{self.global_step}.pth",
self.current_epoch + 1,
Path(self.hparams.model_dir) / f"D_{self.total_batch_idx}.pth",
)
keep_ckpts = self.hparams.train.get("keep_ckpts", 0)
if keep_ckpts > 0:
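
The keep_ckpts branch (its body is collapsed in this view) prunes old checkpoints so only the most recent few G_*.pth / D_*.pth pairs remain. This is not the repository's actual helper, just an illustrative sketch of such a cleanup for step-numbered files like the ones saved above:

from pathlib import Path

def prune_checkpoints(model_dir: Path, keep_ckpts: int) -> None:
    """Keep only the newest `keep_ckpts` checkpoints per prefix (0 keeps all)."""
    if keep_ckpts <= 0:
        return
    for prefix in ("G_", "D_"):
        ckpts = sorted(
            model_dir.glob(f"{prefix}*.pth"),
            key=lambda p: int(p.stem.split("_")[1]),  # sort by step number in the name
        )
        for old in ckpts[:-keep_ckpts]:
            old.unlink()  # delete everything older than the newest keep_ckpts files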
