In [17]:
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from vits.utils.utils import load_wav_to_torch
from vits.utils.mel_processing import spectrogram_torch

from vits.model import commons
from vits.utils import utils
from vits.model.models import SynthesizerTrn
from vits.text.symbols import symbols
from vits.text import cleaned_text_to_sequence, text_to_sequence, batch_text_to_sequence

from scipy.io.wavfile import write


def get_text(text, hps, language_code):
    """Turn raw text into a ``torch.LongTensor`` of symbol ids.

    Args:
        text: Input string handed to ``text_to_sequence``.
        hps: Hyper-parameter object; only ``hps.data.add_blank`` is read.
        language_code: Language selector; coerced to ``str`` before being
            passed to ``text_to_sequence``.

    Returns:
        A ``torch.LongTensor`` built from the id sequence. When
        ``hps.data.add_blank`` is truthy the sequence is first passed through
        ``commons.intersperse`` with ``0`` as the inserted item.
    """
    token_ids = text_to_sequence(text, str(language_code))
    token_ids = commons.intersperse(token_ids, 0) if hps.data.add_blank else token_ids
    return torch.LongTensor(token_ids)

DEBUG:matplotlib.pyplot:Loaded backend module://matplotlib_inline.backend_inline version unknown.


In [18]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
# To pin a specific GPU, replace 'cuda' with an indexed name such as 'cuda:1'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [30]:
# Build the generator from the JSON config and load trained weights into it.
hps = utils.get_hparams_from_file("vits/configs/vits_base.json")
checkpoint_name = "vits_pl_test"
model_name = "G_40000"

# A bare `CUDA_LAUNCH_BLOCKING=1` statement only binds a Python variable and
# never reaches the CUDA runtime; the setting has to go through the process
# environment. `setdefault` keeps any value already exported by the user.
# NOTE(review): this only takes effect if it runs before CUDA is first
# initialised in this kernel -- restart the kernel if a GPU was already used.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    # NOTE(review): hard-coded speaker count; presumably has to match the
    # checkpoint being loaded -- confirm, or read it from the config.
    n_speakers=61,
    **hps.model).to(device)
# Inference only: switch to eval mode. Assigning to `_` keeps the module repr
# out of the cell output.
_ = net_g.eval()

# The third argument is passed as None, as in the original call.
# NOTE(review): presumably the optimizer slot, unused for inference -- confirm.
_ = utils.load_checkpoint(f"vits/checkpoints/{checkpoint_name}/{model_name}.pth", net_g, None)

INFO:root:Loaded checkpoint 'vits/checkpoints/vits_pl_test/G_40000.pth' (iteration 57)


In [31]:
# Synthesize a single utterance, conditioned on a reference recording (`spec`)
# and a speaker id (`sid`), then play it inline.
language_code = 0
input_text = "당신은 누구신지 여쭤보아도 될까요?"
stn_text = get_text(input_text, hps, str(language_code))
print(stn_text.shape)
# Reset the conditioning inputs so values from an earlier run are not reused.
sid=None
spec=None
speaker_path = "/data/dataset/anam/001_jeongwon_perturbed/0aa7e22827e5d591a6d859ff9ba74d09.wav"
# speaker_path = "/data/dataset/anam/001_jeongwon/wavs/ad2978265d57a4432c36c86cae5575ef.wav"

with torch.no_grad():
    # Add a leading batch dimension of size 1 and pass the unpadded length.
    x_tst = stn_text.to(device).unsqueeze(0)
    x_tst_lengths = torch.LongTensor([stn_text.size(0)]).to(device)

    # Load the reference at the configured sampling rate instead of a literal
    # 22050, so it agrees with the rate used for the spectrogram and playback.
    # NOTE(review): assumes the second argument of load_wav_to_torch is the
    # target sampling rate -- confirm against vits.utils.utils.
    ref_audio, _ = load_wav_to_torch(speaker_path, hps.data.sampling_rate)
    ref_audio_norm = ref_audio.unsqueeze(0)
    spec = spectrogram_torch(ref_audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, center=False).to(device)

    # NOTE(review): speaker index 32 is hard-coded; it must be a valid id for
    # the loaded model.
    sid = torch.LongTensor([32]).to(device)
    audio = net_g.infer(x_tst, x_tst_lengths, y=spec, sid=sid, noise_scale=0, noise_scale_w=0, length_scale=1)


# The first element of the infer() result is indexed at [0, 0] to obtain the
# samples handed to the audio widget.
waveform = audio[0][0,0].detach().cpu().float().numpy()
ipd.display(ipd.Audio(waveform, rate=hps.data.sampling_rate, normalize=False))

# from scipy.io.wavfile import write
# write(f"inference_files/{model_name}_inf.wav", 22050, final_audio)

torch.Size([99])


In [None]:
# NOTE(review): this cell repeats the previous inference cell almost verbatim;
# consider moving the shared steps into one helper function that both call.
language_code = 0
input_text = "당신은 누구신지 여쭤보아도 될까요?"
stn_text = get_text(input_text, hps, str(language_code))
print(stn_text.shape)
# Reset the conditioning inputs so values from an earlier run are not reused.
sid=None
spec=None
speaker_path = "/data/dataset/anam/001_jeongwon_perturbed/0aa7e22827e5d591a6d859ff9ba74d09.wav"
# speaker_path = "/data/dataset/anam/001_jeongwon/wavs/ad2978265d57a4432c36c86cae5575ef.wav"

with torch.no_grad():
    # Add a leading batch dimension of size 1 and pass the unpadded length.
    x_tst = stn_text.to(device).unsqueeze(0)
    x_tst_lengths = torch.LongTensor([stn_text.size(0)]).to(device)

    # Load the reference at the configured sampling rate instead of a literal
    # 22050, so it agrees with the rate used for the spectrogram and playback.
    # NOTE(review): assumes the second argument of load_wav_to_torch is the
    # target sampling rate -- confirm against vits.utils.utils.
    ref_audio, _ = load_wav_to_torch(speaker_path, hps.data.sampling_rate)
    ref_audio_norm = ref_audio.unsqueeze(0)
    spec = spectrogram_torch(ref_audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, center=False).to(device)

    # NOTE(review): speaker index 32 is hard-coded; it must be a valid id for
    # the loaded model.
    sid = torch.LongTensor([32]).to(device)

    audio = net_g.infer(x_tst, x_tst_lengths, y=spec, sid=sid, noise_scale=0, noise_scale_w=0, length_scale=1)


# The first element of the infer() result is indexed at [0, 0] to obtain the
# samples handed to the audio widget.
waveform = audio[0][0,0].detach().cpu().float().numpy()
ipd.display(ipd.Audio(waveform, rate=hps.data.sampling_rate, normalize=False))

# from scipy.io.wavfile import write
# write(f"inference_files/{model_name}_inf.wav", 22050, final_audio)

In [None]:
# def dataCollate(txt_lists):
#     max_text_len = max([len(x) for x in txt_lists])
#     text_lengths = torch.LongTensor(len(txt_lists))
#     text_padded = torch.LongTensor(len(txt_lists), max_text_len)
#     text_padded.zero_()

#     for i in range(len(txt_lists)):
#         text = torch.LongTensor(txt_lists[i])
#         text_padded[i, :len(text)] = text
#         text_lengths[i] = len(text)

#     return text_padded, text_lengths

# def find_mask(audio):
#     mask = 0
#     for i in range(len(audio)-1, 0, -1):
#         if torch.abs(audio[i]) < 0.01 and torch.mean(torch.abs(audio[i-100:i])) < 0.001:
#             mask = i
#             break
    
#     return mask

# def concat_audio(audio_list):
#     final_audio = []

#     for audio in audio_list:
#         mask = find_mask(audio[0].data)
#         final_audio.append(audio[0].data[:mask])

#     final_audio = torch.cat(final_audio, dim=0)
#     return final_audio

# language_code = 0
# group_size = 5
# input_text = "안녕하세요. 당신은 누구신가요? 제 이름은 김영재입니다."
# txt_lists = batch_text_to_sequence(input_text, str(language_code), group_size)
# text_padded, text_lengths = dataCollate(txt_lists)

# sid=None
# spec=None
# speaker_path = "/data/dataset/anam/001_jeongwon/wavs/ad2978265d57a4432c36c86cae5575ef.wav"

# with torch.no_grad():

#     x_tst = text_padded.to(device)
#     x_tst_lengths = text_lengths.to(device)

#     ref_audio, _ = load_wav_to_torch(speaker_path, 22050)
#     ref_audio_norm = ref_audio.unsqueeze(0)
#     spec = spectrogram_torch(ref_audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, center=False).to(device)
    
#     # sid = torch.LongTensor([32]).to(device)
#     audio = net_g.infer(x_tst, x_tst_lengths, y=spec, sid=sid, noise_scale=0.333, noise_scale_w=0.1, length_scale=1)

# final_audio = concat_audio(audio[0]).cpu().float().numpy()
# ipd.display(ipd.Audio(final_audio, rate=hps.data.sampling_rate, normalize=False))

# # from scipy.io.wavfile import write
# # write(f"inference_files/{model_name}_inf.wav", 22050, final_audio)



KeyboardInterrupt: 