feat(train): automatically decide batch_size (#342)
34j committed Apr 15, 2023
1 parent 3149469 commit 8ffa128
Showing 2 changed files with 54 additions and 13 deletions.
README.md: 4 changes (2 additions & 2 deletions)
@@ -180,8 +180,8 @@ svc train -t

#### Notes

- - Dataset audio duration per file should be <~ 10s or VRAM will run out.
- - It is recommended to increase the `batch_size` as much as possible in `config.json` before the `train` command to match the VRAM capacity.
+ - Dataset audio duration per file should be <~ 10s.
+ - It is recommended to increase the `batch_size` as much as possible in `config.json` before the `train` command to match the VRAM capacity. Setting `batch_size` to `auto-{init_batch_size}-{max_n_trials}` (or simply `auto`) will automatically increase `batch_size` until an OOM error occurs, but this may not be useful in some cases.
- To use `CREPE`, replace `svc pre-hubert` with `svc pre-hubert -fm crepe`.
- To use `QuickVC`, replace `svc pre-config` with `svc pre-config -t quickvc`.
- Silence removal and volume normalization are automatically performed (as in the upstream repo) and are not required.
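For reference, a minimal sketch (plain Python, not part of this commit) of how the new `auto-{init_batch_size}-{max_n_trials}` value described in the README note above is interpreted, mirroring the parsing logic added to `train.py` below; `parse_batch_size` is an illustrative helper, not a function in the repo:

```python
def parse_batch_size(value):
    """Interpret a config.json `batch_size` value as (mode_or_size, init_val, max_trials)."""
    parts = str(value).split("-")
    mode = parts[0]
    init_val = 2 if len(parts) <= 1 else int(parts[1])     # initial batch size for tuning
    max_trials = 25 if len(parts) <= 2 else int(parts[2])  # number of tuning trials
    if mode == "auto":
        mode = "binsearch"                                  # "auto" means binary search
    if mode in ("power", "binsearch"):
        return mode, init_val, max_trials                   # handed to Tuner.scale_batch_size
    return int(mode), init_val, max_trials                  # plain integer: used as-is


print(parse_batch_size(16))           # (16, 2, 25)
print(parse_batch_size("auto"))       # ('binsearch', 2, 25)
print(parse_batch_size("auto-4-10"))  # ('binsearch', 4, 10)
```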
src/so_vits_svc_fork/train.py: 63 changes (52 additions & 11 deletions)
@@ -10,6 +10,7 @@
import torch
from lightning.pytorch.accelerators import TPUAccelerator
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
@@ -33,9 +34,14 @@


class VCDataModule(pl.LightningDataModule):
batch_size: int

def __init__(self, hparams: Any):
super().__init__()
self.__hparams = hparams
self.batch_size = hparams.train.batch_size
if not isinstance(self.batch_size, int):
self.batch_size = 1
self.collate_fn = TextAudioCollate()

# these should be called in setup(), but we need to calculate check_val_every_n_epoch
@@ -47,7 +53,7 @@ def train_dataloader(self):
self.train_dataset,
# pin_memory=False,
num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 4)),
-             batch_size=self.__hparams.train.batch_size,
+             batch_size=self.batch_size,
collate_fn=self.collate_fn,
)

@@ -90,7 +96,37 @@ def train(
strategy=strategy,
callbacks=[pl.callbacks.RichProgressBar()] if not is_notebook() else None,
)
tuner = Tuner(trainer)
model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)

# automatic batch size scaling
batch_size = hparams.train.batch_size
batch_split = str(batch_size).split("-")
batch_size = batch_split[0]
init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
if batch_size == "auto":
batch_size = "binsearch"
if batch_size in ["power", "binsearch"]:
model.tuning = True
tuner.scale_batch_size(
model,
mode=batch_size,
datamodule=datamodule,
steps_per_trial=1,
init_val=init_val,
max_trials=max_trials,
)
model.tuning = False
else:
batch_size = int(batch_size)
# automatic learning rate scaling is not supported for multiple optimizers
"""if hparams.train.learning_rate == "auto":
lr_finder = tuner.lr_find(model)
LOG.info(lr_finder.results)
fig = lr_finder.plot(suggest=True)
fig.savefig(model_path / "lr_finder.png")"""

trainer.fit(model, datamodule=datamodule)
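Lightning's `Tuner.scale_batch_size` works by rewriting an attribute named `batch_size` (the default `batch_arg_name`) on the model or datamodule between short trial runs, which explains why `VCDataModule` now exposes a plain integer `batch_size` instead of reading `hparams.train.batch_size` directly in `train_dataloader`. A self-contained toy sketch of the same pattern, assuming `lightning>=2.0` (`ToyModule` and `ToyDataModule` are hypothetical, not from this repo):

```python
import torch
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner
from torch.utils.data import DataLoader, TensorDataset


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-3)


class ToyDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.batch_size = 2  # the tuner rewrites this attribute during scaling

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(512, 32)), batch_size=self.batch_size)


trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
Tuner(trainer).scale_batch_size(
    ToyModule(),
    datamodule=ToyDataModule(),
    mode="binsearch",
    steps_per_trial=1,
    init_val=2,
    max_trials=5,
)
```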


@@ -108,15 +144,16 @@ def __init__(self, reset_optimizer: bool = False, **hparams: Any):
)
self.net_d = MultiPeriodDiscriminator(self.hparams.model.use_spectral_norm)
self.automatic_optimization = False
self.learning_rate = self.hparams.train.learning_rate
self.optim_g = torch.optim.AdamW(
self.net_g.parameters(),
-             self.hparams.train.learning_rate,
+             self.learning_rate,
betas=self.hparams.train.betas,
eps=self.hparams.train.eps,
)
self.optim_d = torch.optim.AdamW(
self.net_d.parameters(),
-             self.hparams.train.learning_rate,
+             self.learning_rate,
betas=self.hparams.train.betas,
eps=self.hparams.train.eps,
)
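The new `self.learning_rate` mirror follows the same attribute convention: `Tuner.lr_find` auto-detects an attribute named `learning_rate` or `lr` on the module and updates it in place, which is presumably why the hyperparameter is copied onto the module even though, as the commented-out block in `train()` above notes, learning-rate finding stays disabled here because it does not support multiple optimizers. A hedged sketch of how it would be called for a single-optimizer module, assuming `lightning>=2.0` (`tune_lr` is a hypothetical helper, not part of the commit):

```python
from lightning.pytorch.tuner import Tuner


def tune_lr(trainer, model, datamodule):
    """Run Lightning's LR finder and return the suggested learning rate (or None)."""
    lr_finder = Tuner(trainer).lr_find(model, datamodule=datamodule, attr_name="learning_rate")
    return None if lr_finder is None else lr_finder.suggestion()
```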
@@ -128,13 +165,15 @@ def __init__(self, reset_optimizer: bool = False, **hparams: Any):
)
self.optimizers_count = 2
self.load(reset_optimizer)
self.tuning = False

def on_train_start(self) -> None:
-         self.set_current_epoch(self._temp_epoch)
-         total_batch_idx = self._temp_epoch * len(self.trainer.train_dataloader)
-         self.set_total_batch_idx(total_batch_idx)
-         global_step = total_batch_idx * self.optimizers_count
-         self.set_global_step(global_step)
+         if not self.tuning:
+             self.set_current_epoch(self._temp_epoch)
+             total_batch_idx = self._temp_epoch * len(self.trainer.train_dataloader)
+             self.set_total_batch_idx(total_batch_idx)
+             global_step = total_batch_idx * self.optimizers_count
+             self.set_global_step(global_step)

# check if using tpu
if isinstance(self.trainer.accelerator, TPUAccelerator):
@@ -397,7 +436,9 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
)

accumulate_grad_batches = self.hparams.train.get("accumulate_grad_batches", 1)
-         should_update = (batch_idx + 1) % accumulate_grad_batches == 0
+         should_update = (
+             batch_idx + 1
+         ) % accumulate_grad_batches == 0 or self.trainer.is_last_batch
# optimizer
self.manual_backward(loss_gen_all / accumulate_grad_batches)
if should_update:
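The added `or self.trainer.is_last_batch` ensures that leftover accumulated gradients still trigger an optimizer step when an epoch's batch count is not a multiple of `accumulate_grad_batches`. A quick illustration in plain Python (not from the repo):

```python
accumulate_grad_batches = 4
num_batches = 10  # not a multiple of 4

old_updates = [b for b in range(num_batches) if (b + 1) % accumulate_grad_batches == 0]
new_updates = [
    b for b in range(num_batches)
    if (b + 1) % accumulate_grad_batches == 0 or b == num_batches - 1  # is_last_batch
]

print(old_updates)  # [3, 7]    -> gradients from batches 8 and 9 were never applied
print(new_updates)  # [3, 7, 9] -> the last batch now also steps the optimizers
```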
@@ -462,14 +503,14 @@ def validation_step(self, batch, batch_idx):
utils.save_checkpoint(
self.net_g,
self.optim_g,
-             self.hparams.train.learning_rate,
+             self.learning_rate,
self.current_epoch + 1, # prioritize prevention of undervaluation
Path(self.hparams.model_dir) / f"G_{self.total_batch_idx}.pth",
)
utils.save_checkpoint(
self.net_d,
self.optim_d,
-             self.hparams.train.learning_rate,
+             self.learning_rate,
self.current_epoch + 1,
Path(self.hparams.model_dir) / f"D_{self.total_batch_idx}.pth",
)