Commit 15f1e7f

fix: fix issues when loading legacy checkpoint and fix pre-hubert n_jobs (#236)

34j committed Apr 5, 2023
1 parent d561edb commit 15f1e7f

Showing 5 changed files with 54 additions and 37 deletions.
@@ -7,7 +7,7 @@
     "learning_rate": 0.0001,
     "betas": [0.8, 0.99],
     "eps": 1e-9,
-    "batch_size": 18,
+    "batch_size": 12,
     "fp16_run": false,
     "lr_decay": 0.999875,
     "segment_size": 10240,
@@ -7,7 +7,7 @@
     "learning_rate": 0.0001,
     "betas": [0.8, 0.99],
     "eps": 1e-9,
-    "batch_size": 6,
+    "batch_size": 18,
    "fp16_run": false,
     "lr_decay": 0.999875,
     "segment_size": 10240,
@@ -7,7 +7,7 @@
     "learning_rate": 0.0001,
     "betas": [0.8, 0.99],
     "eps": 1e-9,
-    "batch_size": 18,
+    "batch_size": 12,
     "fp16_run": false,
     "lr_decay": 0.999875,
     "segment_size": 10240,
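The three hunks above change the default batch_size in the bundled training config templates (the file names are not visible in this view); two defaults are lowered (18 -> 12) and one is raised (6 -> 18). For orientation, a minimal sketch of reading such a template is below; it assumes the keys shown sit under a top-level "train" block, as in upstream so-vits-svc configs, and the path is a placeholder.

import json

# Placeholder path; the actual template file names are not shown in this diff.
config_path = "config.json"

with open(config_path, encoding="utf-8") as f:
    config = json.load(f)

train = config["train"]
# These are the defaults the commit adjusts, e.g. batch_size 18 -> 12.
print(train["batch_size"], train["learning_rate"], train["fp16_run"])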
src/so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py (22 changes: 14 additions & 8 deletions)
@@ -10,7 +10,7 @@
 import torch
 import torchaudio
 from fairseq.models.hubert import HubertModel
-from joblib import Parallel, delayed
+from joblib import Parallel, cpu_count, delayed
 from tqdm import tqdm

 import so_vits_svc_fork.f0
@@ -22,8 +22,8 @@
 from .preprocess_utils import check_hubert_min_duration

 LOG = getLogger(__name__)
-HUBERT_MEMORY = 1600
-HUBERT_MEMORY_CREPE = 2600
+HUBERT_MEMORY = 2900
+HUBERT_MEMORY_CREPE = 3900


 def _process_one(
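HUBERT_MEMORY and HUBERT_MEMORY_CREPE are per-worker GPU memory budgets in MiB; the next hunk divides them into the card's memory to pick a worker count. A minimal sketch of querying free and total GPU memory in MiB with plain PyTorch is below; the project's own get_total_gpu_memory helper may be implemented differently, so treat this only as an illustration.

import torch


def gpu_memory_mib(kind: str = "total") -> int | None:
    """Free or total memory of the current CUDA device in MiB, or None without CUDA."""
    if not torch.cuda.is_available():
        return None
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    return (free_bytes if kind == "free" else total_bytes) // (1024 * 1024)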
@@ -124,11 +124,17 @@ def preprocess_hubert_f0(
     utils.ensure_pretrained_model(".", "contentvec")
     hps = utils.get_hparams(config_path)
     if n_jobs is None:
-        memory = get_total_gpu_memory("free")
-        n_jobs = (
-            memory // (HUBERT_MEMORY_CREPE if f0_method == "crepe" else HUBERT_MEMORY)
-            if memory is not None
-            else 1
+        # add cpu_count() to avoid SIGKILL
+        memory = get_total_gpu_memory("total")
+        n_jobs = min(
+            max(
+                memory
+                // (HUBERT_MEMORY_CREPE if f0_method == "crepe" else HUBERT_MEMORY)
+                if memory is not None
+                else 1,
+                1,
+            ),
+            cpu_count(),
         )
         LOG.info(f"n_jobs automatically set to {n_jobs}, memory: {memory} MiB")
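The new selection uses total rather than free GPU memory, clamps the estimate to at least one worker, and caps it at cpu_count() so joblib cannot spawn more processes than there are CPU cores (the in-line comment ties this to avoiding SIGKILL). The same arithmetic as a standalone sketch, with a made-up 24 GiB card and a hypothetical helper name:

from joblib import cpu_count

HUBERT_MEMORY = 2900        # MiB per worker for non-crepe f0 methods
HUBERT_MEMORY_CREPE = 3900  # MiB per worker when f0_method == "crepe"


def auto_n_jobs(total_gpu_mib: int | None, f0_method: str) -> int:
    per_job = HUBERT_MEMORY_CREPE if f0_method == "crepe" else HUBERT_MEMORY
    estimate = total_gpu_mib // per_job if total_gpu_mib is not None else 1
    return min(max(estimate, 1), cpu_count())


# 24576 // 3900 == 6, so a 24 GiB card with crepe gets min(6, cpu_count()) workers.
print(auto_n_jobs(24576, "crepe"))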
src/so_vits_svc_fork/utils.py (63 changes: 37 additions & 26 deletions)
@@ -163,6 +163,32 @@ def get_content(
     return c


+def _substitute_if_same_shape(to_: dict[str, Any], from_: dict[str, Any]) -> None:
+    for k, v in from_.items():
+        if k not in to_:
+            warnings.warn(f"Key {k} not found in model state dict")
+        elif hasattr(v, "shape"):
+            if not hasattr(to_[k], "shape"):
+                raise ValueError(f"Key {k} is not a tensor")
+            if to_[k].shape == v.shape:
+                to_[k] = v
+            else:
+                warnings.warn(
+                    f"Shape mismatch for key {k}, {to_[k].shape} != {v.shape}"
+                )
+        elif isinstance(v, dict):
+            assert isinstance(to_[k], dict)
+            _substitute_if_same_shape(to_[k], v)
+        else:
+            to_[k] = v
+
+
+def safe_load(model: torch.nn.Module, state_dict: dict[str, Any]) -> None:
+    model_state_dict = model.state_dict()
+    _substitute_if_same_shape(model_state_dict, state_dict)
+    model.load_state_dict(model_state_dict)
+
+
 def load_checkpoint(
     checkpoint_path: Path | str,
     model: torch.nn.Module,
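safe_load copies a value from the incoming state dict only when the key exists in the target and the shapes match, warns (instead of failing) on unknown keys and shape mismatches, and recurses into nested dicts. A rough usage sketch with a toy module follows; the Linear layer and the mismatched tensors are invented for illustration.

import torch

from so_vits_svc_fork.utils import safe_load

model = torch.nn.Linear(4, 2)

# A "legacy" state dict: bias matches the model, weight does not.
legacy = {
    "weight": torch.zeros(3, 4),  # shape mismatch -> model keeps its own weight, a warning is emitted
    "bias": torch.ones(2),        # shape matches  -> value is loaded into the model
}

safe_load(model, legacy)  # model.load_state_dict(legacy) would raise here
print(model.bias)         # tensor([1., 1.], ...)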
@@ -174,37 +200,22 @@ def load_checkpoint(
     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
+
+    # safe load module
+    if hasattr(model, "module"):
+        safe_load(model.module, checkpoint_dict["model"])
+    else:
+        safe_load(model, checkpoint_dict["model"])
+    # safe load optim
     if (
         optimizer is not None
         and not skip_optimizer
         and checkpoint_dict["optimizer"] is not None
     ):
-        try:
-            optimizer.load_state_dict(checkpoint_dict["optimizer"])
-        except Exception as e:
-            LOG.exception(e)
-            LOG.warning("Failed to load optimizer state")
-    saved_state_dict = checkpoint_dict["model"]
-    if hasattr(model, "module"):
-        state_dict = model.module.state_dict()
-    else:
-        state_dict = model.state_dict()
-    new_state_dict = {}
-    for k, v in state_dict.items():
-        try:
-            new_state_dict[k] = saved_state_dict[k]
-            assert saved_state_dict[k].shape == v.shape, (
-                saved_state_dict[k].shape,
-                v.shape,
-            )
-        except Exception as e:
-            LOG.exception(e)
-            LOG.error("%s is not in the checkpoint" % k)
-            new_state_dict[k] = v
-    if hasattr(model, "module"):
-        model.module.load_state_dict(new_state_dict)
-    else:
-        model.load_state_dict(new_state_dict)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            safe_load(optimizer, checkpoint_dict["optimizer"])
+
     LOG.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration})")
     return model, optimizer, learning_rate, iteration
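load_checkpoint now routes both the model weights and the optimizer state through safe_load, silencing the warnings for the optimizer, so resuming from a legacy checkpoint degrades gracefully instead of raising. Because _substitute_if_same_shape also recurses into plain nested dicts, safe_load works on optimizer state dicts as well; a small sketch mirroring the "# safe load optim" branch, with a toy model and optimizer, is below.

import warnings

import torch

from so_vits_svc_fork.utils import safe_load

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Optimizer state dicts are nested dicts ("state", "param_groups"); stale or
# mismatched entries from an old run are skipped rather than aborting the resume.
legacy_optim_state = optimizer.state_dict()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    safe_load(optimizer, legacy_optim_state)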
