In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from vits.utils.utils import load_wav_to_torch, plot_spectrogram_to_numpy
from vits.utils.mel_processing import spectrogram_torch, mel_spectrogram_torch, spec_to_mel_torch

from vits.model import commons
from vits.utils import utils
from vits.model.models import SynthesizerTrn
from vits.text.symbols import symbols
from vits.text import cleaned_text_to_sequence, text_to_sequence, batch_text_to_sequence

from scipy.io.wavfile import write


def get_text(text, hps, language_code):
    """Turn a raw sentence into the 1-D LongTensor of symbol ids fed to the model.

    Args:
        text: Input sentence.
        hps: Hyper-parameter object; only ``hps.data.add_blank`` is read here.
        language_code: Language identifier. It is converted with ``str()``
            before being handed to ``text_to_sequence``, so ints and strings
            both work.

    Returns:
        torch.LongTensor of symbol ids. When ``hps.data.add_blank`` is truthy,
        the id list is first passed through ``commons.intersperse`` with
        item 0.
    """
    symbol_ids = text_to_sequence(text, str(language_code))
    if not hps.data.add_blank:
        return torch.LongTensor(symbol_ids)
    # Blank token id 0 is interspersed between the symbol ids.
    return torch.LongTensor(commons.intersperse(symbol_ids, 0))

DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1023)
           2	LOAD_FAST(arg=0, lineno=1026)
           4	LOAD_CONST(arg=1, lineno=1026)
           6	BINARY_SUBSCR(arg=None, lineno=1026)
           8	LOAD_FAST(arg=0, lineno=1026)
          10	LOAD_CONST(arg=2, lineno=1026)
          12	BINARY_SUBSCR(arg=None, lineno=1026)
          14	COMPARE_OP(arg=4, lineno=1026)
          16	LOAD_FAST(arg=0, lineno=1026)
          18	LOAD_CONST(arg=1, lineno=1026)
          20	BINARY_SUBSCR(arg=None, lineno=1026)
          22	LOAD_FAST(arg=0, lineno=1026)
          24	LOAD_CONST(arg=3, lineno=1026)
          26	BINARY_SUBSCR(arg=None, lineno=1026)
          28	COMPARE_OP(arg=5, lineno=1026)
          30	BINARY_AND(arg=None, lineno=1026)
          32	RETURN_VALUE(arg=None, lineno=1026)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_

In [2]:
# Use the GPU when one is visible, otherwise fall back to CPU. A hard-coded
# 'cuda' makes the first `.to(device)` fail on machines without CUDA.
# To pin a specific GPU, replace 'cuda' with e.g. 'cuda:1'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Build the VITS generator from the base config and restore its weights.
hps = utils.get_hparams_from_file("vits/configs/vits_base.json")
checkpoint_name = "vits_pl_test"
model_name = "G_40000"
checkpoint_path = f"vits/checkpoints/{checkpoint_name}/{model_name}.pth"

# NOTE: a bare `CUDA_LAUNCH_BLOCKING=1` statement only creates a Python
# variable; CUDA never sees it. For synchronous CUDA error reporting, set the
# environment variable instead, before CUDA is initialised in this kernel
# (e.g. in the very first cell), or export it in the shell before starting
# Jupyter:
#     os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,                # n_fft // 2 + 1 frequency bins
    hps.train.segment_size // hps.data.hop_length,  # presumably the training segment length in frames
    # NOTE(review): must match the speaker count the checkpoint was trained
    # with; consider reading it from the config instead of hard-coding it.
    n_speakers=61,
    **hps.model).to(device)
_ = net_g.eval()

# The third argument is None, so presumably no optimizer state is restored
# (inference only).
_ = utils.load_checkpoint(checkpoint_path, net_g, None)

  WeightNorm.apply(module, name, dim)


INFO:root:Loaded checkpoint 'vits/checkpoints/vits_pl_test/G_40000.pth' (iteration 57)


In [7]:
# Synthesize one sentence, conditioning on a reference utterance's linear
# spectrogram (passed as `y`) and on a speaker id.
language_code = 0
input_text = "농축수산물 가격이 많이 올랐네요."
# get_text() applies str() to the language code itself.
stn_text = get_text(input_text, hps, language_code)
speaker_path = "/data/dataset/anam/001_jeongwon_perturbed_ss_ep2_0.3_mask/0a97c2e2730097a2a7bed8b792384bb2.wav"
# speaker_path = "/data/gyub/VITS/datasets_모음/datasets_random2/231004_jeongwon/wavs/00d843423158cc5dfb4c24247215fca9.wav"
# speaker_path = "/data/dataset/anam/001_jeongwon/wavs/ad2978265d57a4432c36c86cae5575ef.wav"
speaker_id = 10  # index into the model's speaker table

with torch.no_grad():
    x_tst = stn_text.to(device).unsqueeze(0)  # add a batch dimension of 1
    x_tst_lengths = torch.LongTensor([stn_text.size(0)]).to(device)

    # Load the reference at the config's sampling rate rather than a
    # hard-coded 22050, so it agrees with the `hps.data.sampling_rate`
    # given to spectrogram_torch below.
    # NOTE(review): assumes load_wav_to_torch's 2nd argument is the target
    # sampling rate. Confirm.
    ref_audio, _ = load_wav_to_torch(speaker_path, hps.data.sampling_rate)
    ref_audio_norm = ref_audio.unsqueeze(0)
    spec = spectrogram_torch(ref_audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, center=False).to(device)

    sid = torch.LongTensor([speaker_id]).to(device)
    audio = net_g.infer(x_tst, x_tst_lengths, y=spec, sid=sid, noise_scale=0, noise_scale_w=0, length_scale=1.3)

# infer() returns a sequence whose element 0 is the waveform. It is indexed
# [0, 0] to get a 1-D sample array (presumably (batch, 1, samples); confirm).
waveform = audio[0][0, 0].detach().cpu().float().numpy()
ipd.display(ipd.Audio(waveform, rate=hps.data.sampling_rate, normalize=False))

# from scipy.io.wavfile import write
# write(f"inference_files/{model_name}_inf.wav", hps.data.sampling_rate, waveform)



In [4]:
from sae.models.model import SparseAE

# Round-trip the VITS latent z through a sparse autoencoder (SAE), then vocode
# both z and its reconstruction z_hat so they can be compared by ear.
language_code = 0
input_text = "안녕하세요, 잘 부탁드립니다. 당신은 누구신가요?"
# get_text() applies str() to the language code itself.
stn_text = get_text(input_text, hps, language_code)
speaker_path = "/data/dataset/anam/001_jeongwon_perturbed/0aa7e22827e5d591a6d859ff9ba74d09.wav"
# speaker_path = "/data/dataset/anam/001_jeongwon/wavs/ad2978265d57a4432c36c86cae5575ef.wav"
sae_checkpoint = '/data/youngjae/vits_pl/src/sae/checkpoints/vits_z/30epoch_1.0_0.1_0.1/sae_final.pth'

# 192 matches the channel dimension of z (see the printed shapes). The other
# positional arguments follow SparseAE's signature. TODO: document them.
sae = SparseAE(192, 192 * 16, 1.0, 0, True).to(device)
# map_location lets a checkpoint saved on another device (a different GPU,
# or GPU when running on CPU) load onto `device`.
sae.load_state_dict(torch.load(sae_checkpoint, map_location=device))
sae.eval()

with torch.no_grad():
    x_tst = stn_text.to(device).unsqueeze(0)  # add a batch dimension of 1
    x_tst_lengths = torch.LongTensor([stn_text.size(0)]).to(device)

    # Load at the config's sampling rate instead of a hard-coded 22050.
    # NOTE(review): assumes load_wav_to_torch's 2nd argument is the target
    # sampling rate. Confirm.
    ref_audio, _ = load_wav_to_torch(speaker_path, hps.data.sampling_rate)
    ref_audio_norm = ref_audio.unsqueeze(0)
    spec = spectrogram_torch(ref_audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, center=False).to(device)

    sid = torch.LongTensor([0]).to(device)

    # return_z=True yields the latent plus the extra values from_z() needs.
    z, y_mask, g, max_len = net_g.infer(x_tst, x_tst_lengths, y=spec, sid=sid, noise_scale=0, noise_scale_w=0, length_scale=1, return_z=True)

    # The SAE is applied per frame: (1, C, T) -> (T, C), then transposed back.
    # h holds the SAE's first output (presumably the hidden/sparse codes).
    h, z_hat = sae(z.squeeze(0).T)
    z_hat = z_hat.T.unsqueeze(0)

    print(z.shape)
    print(z_hat.shape)
    assert z_hat.shape == z.shape, "SAE reconstruction must match the latent's shape"

    audio = net_g.from_z(z, y_mask, g, max_len)
    audio_hat = net_g.from_z(z_hat, y_mask, g, max_len)

ipd.display(ipd.Audio(audio[0, 0].detach().cpu().float().numpy(), rate=hps.data.sampling_rate, normalize=False))
ipd.display(ipd.Audio(audio_hat[0, 0].detach().cpu().float().numpy(), rate=hps.data.sampling_rate, normalize=False))

# from scipy.io.wavfile import write
# write(f"original.wav", hps.data.sampling_rate, audio[0,0].cpu().float().numpy())

DEBUG:git.cmd:Popen(['git', 'version'], cwd=/data/youngjae/vits_pl/src, stdin=None, shell=False, universal_newlines=False)
DEBUG:git.cmd:Popen(['git', 'version'], cwd=/data/youngjae/vits_pl/src, stdin=None, shell=False, universal_newlines=False)
DEBUG:wandb.docker.auth:Trying paths: ['/home/ubuntu/.docker/config.json', '/home/ubuntu/.dockercfg']
DEBUG:wandb.docker.auth:No config file found
torch.Size([1, 192, 261])
torch.Size([1, 192, 261])


In [2]:
# NOTE(review): disabled batch-inference experiment (every line is commented
# out). Address the following before re-enabling it:
#   - `find_mask` starts with mask = 0 and only changes it when the silence
#     test passes. If no sample qualifies, `concat_audio` appends
#     `audio[0].data[:0]`, an empty tensor, and that chunk is silently dropped.
#   - For i < 100, `audio[i-100:i]` has a negative start index, so it does not
#     select the 100 samples preceding i.
#   - `sid` stays None (its assignment is commented out). Confirm that
#     `net_g.infer` accepts sid=None for a multi-speaker model.
#   - The reference audio is loaded at a hard-coded 22050 while the STFT uses
#     `hps.data.sampling_rate`.
# def dataCollate(txt_lists):
#     max_text_len = max([len(x) for x in txt_lists])
#     text_lengths = torch.LongTensor(len(txt_lists))
#     text_padded = torch.LongTensor(len(txt_lists), max_text_len)
#     text_padded.zero_()

#     for i in range(len(txt_lists)):
#         text = torch.LongTensor(txt_lists[i])
#         text_padded[i, :len(text)] = text
#         text_lengths[i] = len(text)

#     return text_padded, text_lengths

# def find_mask(audio):
#     mask = 0
#     for i in range(len(audio)-1, 0, -1):
#         if torch.abs(audio[i]) < 0.01 and torch.mean(torch.abs(audio[i-100:i])) < 0.001:
#             mask = i
#             break
    
#     return mask

# def concat_audio(audio_list):
#     final_audio = []

#     for audio in audio_list:
#         mask = find_mask(audio[0].data)
#         final_audio.append(audio[0].data[:mask])

#     final_audio = torch.cat(final_audio, dim=0)
#     return final_audio

# language_code = 0
# group_size = 5
# input_text = "안녕하세요. 당신은 누구신가요? 제 이름은 김영재입니다."
# txt_lists = batch_text_to_sequence(input_text, str(language_code), group_size)
# text_padded, text_lengths = dataCollate(txt_lists)

# sid=None
# spec=None
# speaker_path = "/data/dataset/anam/001_jeongwon/wavs/ad2978265d57a4432c36c86cae5575ef.wav"

# with torch.no_grad():

#     x_tst = text_padded.to(device)
#     x_tst_lengths = text_lengths.to(device)

#     ref_audio, _ = load_wav_to_torch(speaker_path, 22050)
#     ref_audio_norm = ref_audio.unsqueeze(0)
#     spec = spectrogram_torch(ref_audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, center=False).to(device)
    
#     # sid = torch.LongTensor([32]).to(device)
#     audio = net_g.infer(x_tst, x_tst_lengths, y=spec, sid=sid, noise_scale=0.333, noise_scale_w=0.1, length_scale=1)

# final_audio = concat_audio(audio[0]).cpu().float().numpy()
# ipd.display(ipd.Audio(final_audio, rate=hps.data.sampling_rate, normalize=False))

# # from scipy.io.wavfile import write
# # write(f"inference_files/{model_name}_inf.wav", 22050, final_audio)

In [11]:
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd
import os
import numpy as np
import os
import random

hps = utils.get_hparams_from_file("vits/configs/vits_ref.json")

file_dir = "/data/gyub/VITS/datasets_모음/datasets_random2/231004_jeongwon/wavs"
files = os.listdir(file_dir)
# random.shuffle(files)

# files = files[:10]


for file in files:
    if ".pt" in file:
        continue
    file_path = f"{file_dir}/{file}"

    ref_audio, _ = load_wav_to_torch(file_path, 22050)
    ref_audio_norm = ref_audio.unsqueeze(0)

    spec = spectrogram_torch(ref_audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, center=False)

    mel = spec_to_mel_torch(
        spec,
        hps.data.filter_length,
        hps.data.n_mel_channels,
        hps.data.sampling_rate,
        hps.data.mel_fmin,
        hps.data.mel_fmax,
    )

    fig, ax = plt.subplots(figsize=(10,2))
    im = ax.imshow(mel.squeeze(0), aspect="auto", origin="lower",
                    interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()
    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))


    out_path = os.path.join("rand_mel/", f"{file}.png")
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()


DEBUG:matplotlib.pyplot:Loaded backend module://matplotlib_inline.backend_inline version unknown.


DEBUG:matplotlib.colorbar:locator: <matplotlib.ticker.AutoLocator object at 0x7f8313128910>


  data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
  data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')


DEBUG:matplotlib.colorbar:locator: <matplotlib.ticker.AutoLocator object at 0x7f83418a65e0>
DEBUG:matplotlib.colorbar:locator: <matplotlib.ticker.AutoLocator object at 0x7f83103c2a00>
DEBUG:matplotlib.colorbar:locator: <matplotlib.ticker.AutoLocator object at 0x7f8313a371f0>
DEBUG:matplotlib.colorbar:locator: <matplotlib.ticker.AutoLocator object at 0x7f8313bf3b20>
DEBUG:matplotlib.colorbar:locator: <matplotlib.ticker.AutoLocator object at 0x7f8302bf5fd0>


KeyboardInterrupt: 

<Figure size 1000x200 with 0 Axes>