Skip to content

Commit

Permalink
fix: fix RTF calculation (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
34j committed Mar 22, 2023
1 parent 4e45555 commit fb25500
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 35 deletions.
5 changes: 2 additions & 3 deletions src/so_vits_svc_fork/inference/infer_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,9 @@ def infer(
predict_f0=auto_predict_f0,
noice_scale=noise_scale,
)[0, 0].data.float()
realtime_coef = len(audio) / (t.elapsed * self.target_sample)
audio_duration = audio.shape[-1] / self.target_sample
LOG.info(
f"Inferece time: {t.elapsed:.2f}s, Realtime coef: {realtime_coef:.2f} "
f"Input shape: {audio.shape}, Output shape: {audio.shape}"
f"Inferece time: {t.elapsed:.2f}s, RTF: {t.elapsed / audio_duration:.2f}"
)
return audio, audio.shape[-1]

Expand Down
9 changes: 6 additions & 3 deletions src/so_vits_svc_fork/inference_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def realtime(
f"Input Device: {devices[input_device]['name']}, Output Device: {devices[output_device]['name']}"
)

# the model realtime coef is somewhat significantly low only in the first inference
# the model RTL is somewhat significantly high only in the first inference
# there could be no better way to warm up the model than to do a dummy inference
# (there are not differences in the behavior of the model between the first and the later inferences)
# so we do a dummy inference to warm up the model (1 second of audio)
Expand Down Expand Up @@ -211,7 +211,10 @@ def callback(
outdata[:] = (indata + inference) / 2
else:
outdata[:] = inference
LOG.info(f"True Realtime coef: {block_seconds / t.elapsed:.2f}")
rtf = t.elapsed / block_seconds
LOG.info(f"Realtime inference time: {t.elapsed:.3f}s, RTF: {rtf:.3f}")
if rtf > 1:
LOG.warning("RTF is too high, consider increasing block_seconds")

with sd.Stream(
device=(input_device, output_device),
Expand All @@ -221,6 +224,6 @@ def callback(
blocksize=int(block_seconds * svc_model.target_sample),
latency="low",
) as stream:
LOG.info(f"Latency: {stream.latency}")
while True:
LOG.info(f"Latency: {stream.latency}")
sd.sleep(1000)
69 changes: 40 additions & 29 deletions src/so_vits_svc_fork/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import requests
import torch
import torchcrepe
from cm_time import timer
from numpy import dtype, float32, ndarray
from scipy.io.wavfile import read
from torch import FloatTensor, Tensor
Expand Down Expand Up @@ -245,20 +246,24 @@ def compute_f0(
method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "crepe",
**kwargs,
):
wav_numpy = wav_numpy.astype(np.float32)
wav_numpy /= np.quantile(np.abs(wav_numpy), 0.999)
if method in ["dio", "harvest"]:
return compute_f0_pyworld(wav_numpy, p_len, sampling_rate, hop_length, method)
elif method == "crepe":
return compute_f0_crepe(wav_numpy, p_len, sampling_rate, hop_length, **kwargs)
elif method == "crepe-tiny":
return compute_f0_crepe(
wav_numpy, p_len, sampling_rate, hop_length, model="tiny", **kwargs
)
elif method == "parselmouth":
return compute_f0_parselmouth(wav_numpy, p_len, sampling_rate, hop_length)
else:
raise ValueError("type must be dio, crepe, harvest or parselmouth")
with timer() as t:
wav_numpy = wav_numpy.astype(np.float32)
wav_numpy /= np.quantile(np.abs(wav_numpy), 0.999)
if method in ["dio", "harvest"]:
f0 = compute_f0_pyworld(wav_numpy, p_len, sampling_rate, hop_length, method)
elif method == "crepe":
f0 = compute_f0_crepe(wav_numpy, p_len, sampling_rate, hop_length, **kwargs)
elif method == "crepe-tiny":
f0 = compute_f0_crepe(
wav_numpy, p_len, sampling_rate, hop_length, model="tiny", **kwargs
)
elif method == "parselmouth":
f0 = compute_f0_parselmouth(wav_numpy, p_len, sampling_rate, hop_length)
else:
raise ValueError("type must be dio, crepe, harvest or parselmouth")
rtf = t.elapsed / (len(wav_numpy) / sampling_rate)
LOG.info(f"F0 inference time: {t.elapsed:.3f}s, RTF: {rtf:.3f}")
return f0


def f0_to_coarse(f0: torch.Tensor | float):
Expand Down Expand Up @@ -338,21 +343,27 @@ def get_hubert_model():


def get_hubert_content(hmodel, wav_16k_tensor):
feats = wav_16k_tensor
if feats.dim() == 2: # double channels
feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
feats = feats.view(1, -1)
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
inputs = {
"source": feats.to(wav_16k_tensor.device),
"padding_mask": padding_mask.to(wav_16k_tensor.device),
"output_layer": 9, # layer 9
}
with torch.no_grad():
logits = hmodel.extract_features(**inputs)
feats = hmodel.final_proj(logits[0])
return feats.transpose(1, 2)
with timer() as t:
feats = wav_16k_tensor
if feats.dim() == 2: # double channels
feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
feats = feats.view(1, -1)
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
inputs = {
"source": feats.to(wav_16k_tensor.device),
"padding_mask": padding_mask.to(wav_16k_tensor.device),
"output_layer": 9, # layer 9
}
with torch.no_grad():
logits = hmodel.extract_features(**inputs)
feats = hmodel.final_proj(logits[0])
res = feats.transpose(1, 2)
wav_len = wav_16k_tensor.shape[-1] / 16000
LOG.info(
f"HuBERT inference time : {t.elapsed:.3f}s, RTF: {t.elapsed / wav_len:.3f}"
)
return res


def get_content(cmodel: Any, y: ndarray) -> ndarray:
Expand Down

0 comments on commit fb25500

Please sign in to comment.