Skip to content

Commit

Permalink
fix(train): set persistent_workers = True in DataLoader for performan…
Browse files Browse the repository at this point in the history
…ce, do not save checkpoints, fix logging issue and multiple warning issues, do not do validation when global_step == 0 (#384)
  • Loading branch information
34j committed Apr 18, 2023
1 parent 924342f commit 6cab9af
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 12 deletions.
3 changes: 1 addition & 2 deletions src/so_vits_svc_fork/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
if IS_TEST:
LOG.debug("Test mode is on.")

LOG.info(f"Version: {__version__}")


class RichHelpFormatter(click.HelpFormatter):
def __init__(
Expand All @@ -31,6 +29,7 @@ def __init__(
) -> None:
width = 100
super().__init__(indent_increment, width, max_width)
LOG.info(f"Version: {__version__}")


def patch_wrap_text():
Expand Down
2 changes: 2 additions & 0 deletions src/so_vits_svc_fork/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pebble import ProcessFuture, ProcessPool
from tqdm.tk import tqdm_tk

from . import __version__
from .utils import ensure_pretrained_model, get_optimal_device

GUI_DEFAULT_PRESETS_PATH = Path(__file__).parent / "default_gui_presets.json"
Expand Down Expand Up @@ -97,6 +98,7 @@ def after_inference(window: sg.Window, path: Path, auto_play: bool, output_path:


def main():
LOG.info(f"version: {__version__}")
try:
ensure_pretrained_model(".", "contentvec", tqdm_cls=tqdm_tk)
except Exception as e:
Expand Down
4 changes: 0 additions & 4 deletions src/so_vits_svc_fork/logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import sys
import warnings
from logging import (
DEBUG,
INFO,
Expand Down Expand Up @@ -36,9 +35,6 @@ def init_logger() -> None:
if IS_TEST:
getLogger(package_name).setLevel(DEBUG)
captureWarnings(True)
warnings.filterwarnings(
"ignore", category=UserWarning, message="TypedStorage is deprecated"
)
LOGGER_INIT = True


Expand Down
11 changes: 7 additions & 4 deletions src/so_vits_svc_fork/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from .modules.synthesizers import SynthesizerTrn

LOG = getLogger(__name__)
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")


Expand All @@ -51,16 +50,15 @@ def __init__(self, hparams: Any):
def train_dataloader(self):
return DataLoader(
self.train_dataset,
# pin_memory=False,
num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 4)),
num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
batch_size=self.batch_size,
collate_fn=self.collate_fn,
persistent_workers=True,
)

def val_dataloader(self):
return DataLoader(
self.val_dataset,
# pin_memory=False,
batch_size=1,
collate_fn=self.collate_fn,
)
Expand Down Expand Up @@ -95,6 +93,8 @@ def train(
else 32,
strategy=strategy,
callbacks=[pl.callbacks.RichProgressBar()] if not is_notebook() else None,
benchmark=True,
enable_checkpointing=False,
)
tuner = Tuner(trainer)
model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)
Expand Down Expand Up @@ -482,6 +482,9 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
self.scheduler_d.step()

def validation_step(self, batch, batch_idx):
# avoid logging with wrong global step
if self.global_step == 0:
return
with torch.no_grad():
self.net_g.eval()
c, f0, _, mel, y, g, _, uv = batch
Expand Down
8 changes: 6 additions & 2 deletions src/so_vits_svc_fork/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _substitute_if_same_shape(to_: dict[str, Any], from_: dict[str, Any]) -> Non
shape_missmatch = []
for k, v in from_.items():
if k not in to_:
warnings.warn(f"Key {k} not found in model state dict")
pass
elif hasattr(v, "shape"):
if not hasattr(to_[k], "shape"):
raise ValueError(f"Key {k} is not a tensor")
Expand Down Expand Up @@ -225,7 +225,11 @@ def load_checkpoint(
if not Path(checkpoint_path).is_file():
raise FileNotFoundError(f"File {checkpoint_path} not found")
with Path(checkpoint_path).open("rb") as f:
checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=UserWarning, message="TypedStorage is deprecated"
)
checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True)
iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]

Expand Down

0 comments on commit 6cab9af

Please sign in to comment.