In [19]:
import logging
import sys,os
from pathlib import Path

from scipy.fftpack import shift
import torch
import argparse
import numpy as np

from omegaconf import OmegaConf
from scipy.io.wavfile import write
from vits.models import SynthesizerInfer
from pitch import load_csv_pitch
from feature_retrieval import IRetrieval, DummyRetrieval, FaissIndexRetrieval, load_retrieve_index

logger = logging.getLogger(__name__)

os.environ['USE_FLASH_ATTENTION']='1'
def get_speaker_name_from_path(speaker_path: Path) -> str:
    suffixes = "".join(speaker_path.suffixes)
    filename = speaker_path.name
    return filename.rstrip(suffixes)


def create_retrival(cli_args) -> IRetrieval:
    if not cli_args.enable_retrieval:
        logger.info("infer without retrival")
        return DummyRetrieval()
    else:
        logger.info("load index retrival model")

    speaker_name = get_speaker_name_from_path(Path(args.spk))
    base_path = Path(".").absolute() / "data_svc" / "indexes" / speaker_name
    print('base_path',base_path)
    if cli_args.hubert_index_path:
        hubert_index_filepath = cli_args.hubert_index_path
    else:
        index_name = f"{cli_args.retrieval_index_prefix}hubert.index"
        hubert_index_filepath = base_path / index_name

    if cli_args.whisper_index_path:
        whisper_index_filepath = cli_args.whisper_index_path
    else:
        index_name = f"{cli_args.retrieval_index_prefix}whisper.index"
        whisper_index_filepath = base_path / index_name

    return FaissIndexRetrieval(
        hubert_index=load_retrieve_index(
            filepath=hubert_index_filepath,
            ratio=cli_args.retrieval_ratio,
            n_nearest_vectors=cli_args.n_retrieval_vectors
        ),
        whisper_index=load_retrieve_index(
            filepath=whisper_index_filepath,
            ratio=cli_args.retrieval_ratio,
            n_nearest_vectors=cli_args.n_retrieval_vectors
        ),
    )


def load_svc_model(checkpoint_path, model):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model_g"]
    state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except:
            print("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    return model


def svc_infer(model, retrieval: IRetrieval, spk, pit, ppg, vec, hp, device):
    len_pit = pit.size()[0]
    len_vec = vec.size()[0]
    len_ppg = ppg.size()[0]
    len_min = min(len_pit, len_vec)
    len_min = min(len_min, len_ppg)
    pit = pit[:len_min]
    vec = vec[:len_min, :]
    ppg = ppg[:len_min, :]

    with torch.no_grad():
        spk = spk.unsqueeze(0).to(device)
        source = pit.unsqueeze(0).to(device)
        source = model.pitch2source(source)
        pitwav = model.source2wav(source)
        write("svc_out_pit.wav", hp.data.sampling_rate, pitwav)
        print('pit?',pitwav.shape)

        hop_size = hp.data.hop_length
        all_frame = len_min
        hop_frame = 10
        out_chunk = 2500  # 25 S
        out_index = 0
        out_audio = []
        print(all_frame)
        while (out_index < all_frame):

            if (out_index == 0):  # start frame
                cut_s = 0
                cut_s_out = 0
            else:
                cut_s = out_index - hop_frame
                cut_s_out = hop_frame * hop_size

            if (out_index + out_chunk + hop_frame > all_frame):  # end frame
                cut_e = all_frame
                cut_e_out = -1
            else:
                cut_e = out_index + out_chunk + hop_frame
                cut_e_out = -1 * hop_frame * hop_size

            sub_ppg = retrieval.retriv_hubert(ppg[cut_s:cut_e, :])
            sub_vec = retrieval.retriv_whisper(vec[cut_s:cut_e, :])
            sub_ppg = sub_ppg.unsqueeze(0).to(device)
            sub_vec = sub_vec.unsqueeze(0).to(device)
            sub_pit = pit[cut_s:cut_e].unsqueeze(0).to(device)
            sub_len = torch.LongTensor([cut_e - cut_s]).to(device)
            sub_har = source[:, :, cut_s *
                             hop_size:cut_e * hop_size].to(device)
            sub_out = model.inference(
                sub_ppg, sub_vec, sub_pit, spk, sub_len, sub_har)
            print('inference?',sub_out.shape)
            
            sub_out = sub_out[0, 0].data.cpu().detach().numpy()

            sub_out = sub_out[cut_s_out:cut_e_out]
            out_audio.extend(sub_out)
            out_index = out_index + out_chunk

        out_audio = np.asarray(out_audio)
        print('out?',out_audio.shape)
    return out_audio


def main(args):
    if (args.ppg == None):
        args.ppg = "svc_tmp.ppg.npy"
        print(
            f"Auto run : python whisper/inference.py -w {args.wave} -p {args.ppg}")
        os.system(f"python whisper/inference.py -w {args.wave} -p {args.ppg}")

    if (args.vec == None):
        args.vec = "svc_tmp.vec.npy"
        print(
            f"Auto run : python hubert/inference.py -w {args.wave} -v {args.vec}")
        os.system(f"python hubert/inference.py -w {args.wave} -v {args.vec}")

    if (args.pit == None):
        args.pit = "svc_tmp.pit.csv"
        print(
            f"Auto run : python pitch/inference.py -w {args.wave} -p {args.pit}")
        os.system(f"python pitch/inference.py -w {args.wave} -p {args.pit}")

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hp = OmegaConf.load(args.config)
    model = SynthesizerInfer(
        hp.data.filter_length // 2 + 1,
        hp.data.segment_size // hp.data.hop_length,
        hp)
    load_svc_model(args.model, model)
    retrieval = create_retrival(args)
    print(retrieval)
    model.eval()
    model.to(device)

    spk = np.load(args.spk)
    spk = torch.FloatTensor(spk)

    ppg = np.load(args.ppg)
    ppg = np.repeat(ppg, 2, 0)  # 320 PPG -> 160 * 2
    ppg = torch.FloatTensor(ppg)
    # ppg = torch.zeros_like(ppg)

    vec = np.load(args.vec)
    vec = np.repeat(vec, 2, 0)  # 320 PPG -> 160 * 2
    vec = torch.FloatTensor(vec)
    # vec = torch.zeros_like(vec)

    pit = load_csv_pitch(args.pit)
    print("pitch shift: ", args.shift)
    if (args.shift == 0):
        pass
    else:
        pit = np.array(pit)
        source = pit[pit > 0]
        source_ave = source.mean()
        source_min = source.min()
        source_max = source.max()
        print(f"source pitch statics: mean={source_ave:0.1f}, \
                min={source_min:0.1f}, max={source_max:0.1f}")
        shift = args.shift
        shift = 2 ** (shift / 12)
        pit = pit * shift
    pit = torch.FloatTensor(pit)

    out_audio = svc_infer(model, retrieval, spk, pit, ppg, vec, hp, device)
    write(f"{args.out}.wav", hp.data.sampling_rate, out_audio)


class args:
    config = 'configs/base.yaml'
    model = 'soyo_avg.pth'
    spk = r'./data_svc\speaker\soyo\soyo夹_51.spk.npy'                 
    #'./data_svc\speaker\soyo\キリトリセン_6.spk.npy' 
    wave = "1528567343取.wav"
    shift = 0
    out = '1528567343'
    ppg = None
    vec = None
    pit = None
    enable_retrieval = False
    retrieval_index_prefix = ''
    hubert_index_path = None
    whisper_index_path = None
    
    retrieval_ratio = 0.5
    n_retrieval_vectors = 3
    
    debug = False

In [20]:
main(args=args)

Auto run : python whisper/inference.py -w 1528567343取.wav -p svc_tmp.ppg.npy
Auto run : python hubert/inference.py -w 1528567343取.wav -v svc_tmp.vec.npy
Auto run : python pitch/inference.py -w 1528567343取.wav -p svc_tmp.pit.csv


INFO:__main__:infer without retrival


<feature_retrieval.retrieval.DummyRetrieval object at 0x000002608B2CD1F0>
pitch shift:  0
pit? (120320,)
376
inference? torch.Size([1, 1, 120320])
out? (120319,)


In [14]:
!python svc_inference_post.py --ref 詩超絆1.wav --svc svc_out.wav --out svc_out_post.wav

svc in wave : 詩超絆1.wav
svc out wave : svc_out.wav
svc post wave : svc_out_post.wav


  x, sr = librosa.load(file, sr=sr)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
Traceback (most recent call last):
  File "e:\.env\sovit5\lib\site-packages\librosa\core\audio.py", line 175, in load
    y, sr_native = __soundfile_load(path, offset, duration, dtype)
  File "e:\.env\sovit5\lib\site-packages\librosa\core\audio.py", line 208, in __soundfile_load
    context = sf.SoundFile(path)
  File "e:\.env\sovit5\lib\site-packages\soundfile.py", line 658, in __init__
    self._file = self._open(file, mode_int, closefd)
  File "e:\.env\sovit5\lib\site-packages\soundfile.py", line 1216, in _open
    raise LibsndfileError(err, prefix="Error opening {0!r}: ".format(self.name))
soundfile.LibsndfileError: Error opening '詩超絆1.wav': System error.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "e:\Data\Code\so-vits-svc-5.0\svc