In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from vits.utils.utils import load_wav_to_torch
from vits.utils.mel_processing import spectrogram_torch

from vits.model import commons
from vits.utils import utils
from vits.model.models import SynthesizerTrn
from vits.text.symbols import symbols
from vits.text import cleaned_text_to_sequence, text_to_sequence, batch_text_to_sequence

from scipy.io.wavfile import write


def get_text(text, hps, language_code):
    """Turn an input sentence into the 1-D tensor of symbol ids fed to the model.

    Args:
        text: sentence to synthesise.
        hps: hyper-parameter object; only ``hps.data.add_blank`` is read here.
        language_code: language identifier; it is converted with ``str()``
            before being handed to ``text_to_sequence``.

    Returns:
        torch.LongTensor holding the symbol ids. When ``hps.data.add_blank`` is
        truthy the id list is first passed through
        ``commons.intersperse(ids, 0)`` (presumably inserting blank id 0
        between symbols — confirm against ``commons``).
    """
    ids = text_to_sequence(text, str(language_code))
    if hps.data.add_blank:
        ids = commons.intersperse(ids, 0)
    return torch.LongTensor(ids)

DEBUG:numba.core.byteflow:bytecode dump:
>          0	NOP(arg=None, lineno=1023)
           2	LOAD_FAST(arg=0, lineno=1026)
           4	LOAD_CONST(arg=1, lineno=1026)
           6	BINARY_SUBSCR(arg=None, lineno=1026)
           8	LOAD_FAST(arg=0, lineno=1026)
          10	LOAD_CONST(arg=2, lineno=1026)
          12	BINARY_SUBSCR(arg=None, lineno=1026)
          14	COMPARE_OP(arg=4, lineno=1026)
          16	LOAD_FAST(arg=0, lineno=1026)
          18	LOAD_CONST(arg=1, lineno=1026)
          20	BINARY_SUBSCR(arg=None, lineno=1026)
          22	LOAD_FAST(arg=0, lineno=1026)
          24	LOAD_CONST(arg=3, lineno=1026)
          26	BINARY_SUBSCR(arg=None, lineno=1026)
          28	COMPARE_OP(arg=5, lineno=1026)
          30	BINARY_AND(arg=None, lineno=1026)
          32	RETURN_VALUE(arg=None, lineno=1026)
DEBUG:numba.core.byteflow:pending: deque([State(pc_initial=0 nstack_initial=0)])
DEBUG:numba.core.byteflow:stack: []
DEBUG:numba.core.byteflow:state.pc_initial: State(pc_initial=0 nstack_

In [2]:
# Run on cuda:3 (the GPU this experiment used) when it exists; otherwise fall
# back to CPU so the notebook still runs on machines with fewer GPUs.
# torch.cuda.device_count() returns 0 when CUDA is unavailable.
device = 'cuda:3' if torch.cuda.device_count() > 3 else 'cpu'
device

In [68]:
hps = utils.get_hparams_from_file("vits/configs/vits_base.json")
checkpoint_name = "attack_origin_single"
model_name = "G_3000"

# CUDA_LAUNCH_BLOCKING (synchronous kernel launches, for readable CUDA error
# traces) is read from the process *environment*; binding a Python variable of
# that name has no effect. It is only honoured if set before CUDA is
# initialised in this process — if a GPU has already been used in this
# kernel, restart the kernel for it to apply.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,                 # one-sided STFT bins for n_fft = filter_length
    hps.train.segment_size // hps.data.hop_length,   # segment_size in samples -> frames
    n_speakers=1,
    **hps.model).to(device)
_ = net_g.eval()

checkpoint_path = f"vits/checkpoints/{checkpoint_name}/{model_name}.pth"
# Third argument (optimizer) is None: only the generator weights are restored.
_ = utils.load_checkpoint(checkpoint_path, net_g, None)

INFO:root:Loaded checkpoint 'vits/checkpoints/attack_origin_single/G_3000.pth' (iteration 200)


In [71]:
def find_mask(audio, silence_thresh=0.01, window=100):
    """Return the index at which to cut the tail of a waveform (use ``audio[:mask]``).

    Scanning from the end, this returns the largest index ``i`` with
    ``window <= i < len(audio)`` such that

    * ``|audio[i]| < silence_thresh`` and
    * ``mean(|audio[i-window:i]|) < silence_thresh / 10``,

    i.e. the last sample that is itself quiet and is preceded by a
    ``window``-sample stretch of near-silence.

    Args:
        audio: 1-D torch.Tensor waveform (any device).
        silence_thresh: amplitude below which a sample counts as silent.
        window: length of the span before the candidate sample whose mean
            absolute amplitude must be below ``silence_thresh / 10``.

    Returns:
        int: the index described above, or ``len(audio)`` (keep everything)
        when no such index exists, when ``len(audio) <= window``, or when
        ``window < 1``.
    """
    L = len(audio)
    # A full window must fit before the candidate sample.
    if window < 1 or L <= window:
        return L

    mag = audio.abs()
    # Vectorised scan: a per-sample Python loop over tensor elements is very
    # slow, and for CUDA tensors forces a host/device sync on every sample.
    # win_mean[j] = mean(mag[j:j+window]) for every start j.
    win_mean = mag.unfold(0, window, 1).mean(dim=-1)
    # Candidate i runs over [window, L-1]; its preceding span starts at j = i - window.
    quiet = (mag[window:] < silence_thresh) & (win_mean[:L - window] < silence_thresh / 10)
    hits = quiet.nonzero()
    if hits.numel() == 0:
        # Not found: keep the whole clip.
        return L
    return int(hits[-1, 0]) + window

language_code = 0
input_text = "농축수산물 가격이 많이 올랐네요."
stn_text = get_text(input_text, hps, str(language_code))

# sid is left as None; the reference spectrogram `spec` is passed to infer()
# as `y` instead (presumably for speaker conditioning — confirm in SynthesizerTrn).
sid = None
spec = None
speaker_path = "/data/dataset/anam/001_jeongwon_perturbed/0aa7e22827e5d591a6d859ff9ba74d09.wav"
# speaker_path = "/data/dataset/anam/001_jeongwon/wavs/ad2978265d57a4432c36c86cae5575ef.wav"

with torch.no_grad():
    x_tst = stn_text.to(device).unsqueeze(0)                        # add batch dim: (1, T_text)
    x_tst_lengths = torch.LongTensor([stn_text.size(0)]).to(device)

    # Use the configured rate rather than a hard-coded 22050 so the reference
    # is loaded for the same rate the spectrogram below is computed with
    # (assumes the 2nd argument of load_wav_to_torch is the target rate — confirm).
    ref_audio, _ = load_wav_to_torch(speaker_path, hps.data.sampling_rate)
    # NOTE(review): despite the name, no amplitude normalisation happens here;
    # this is only right if load_wav_to_torch already returns samples in [-1, 1].
    ref_audio_norm = ref_audio.unsqueeze(0)
    spec = spectrogram_torch(
        ref_audio_norm,
        hps.data.filter_length,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        center=False,
    ).to(device)

    # sid = torch.LongTensor([32]).to(device)
    audio = net_g.infer(x_tst, x_tst_lengths, y=spec, sid=sid,
                        noise_scale=0.333, noise_scale_w=0.1, length_scale=1.2)

# infer() output: audio[0] is indexed [0, 0] to get the 1-D waveform of batch item 0.
wav_out = audio[0][0, 0].data.cpu().float().numpy()

# Optional tail trimming — find_mask expects a 1-D waveform, not the infer() result:
# mask = find_mask(audio[0][0, 0], silence_thresh=0.05, window=100)
# wav_out = wav_out[:mask]

ipd.display(ipd.Audio(wav_out, rate=hps.data.sampling_rate, normalize=False))

# write(f"inference_files/{model_name}_inf.wav", hps.data.sampling_rate, wav_out)



In [None]:
# def dataCollate(txt_lists):
#     max_text_len = max([len(x) for x in txt_lists])
#     text_lengths = torch.LongTensor(len(txt_lists))
#     text_padded = torch.LongTensor(len(txt_lists), max_text_len)
#     text_padded.zero_()

#     for i in range(len(txt_lists)):
#         text = torch.LongTensor(txt_lists[i])
#         text_padded[i, :len(text)] = text
#         text_lengths[i] = len(text)

#     return text_padded, text_lengths

# def find_mask(audio):
#     mask = 0
#     for i in range(len(audio)-1, 0, -1):
#         if torch.abs(audio[i]) < 0.01 and torch.mean(torch.abs(audio[i-100:i])) < 0.001:
#             mask = i
#             break
    
#     return mask

# def concat_audio(audio_list):
#     final_audio = []

#     for audio in audio_list:
#         mask = find_mask(audio[0].data)
#         final_audio.append(audio[0].data[:mask])

#     final_audio = torch.cat(final_audio, dim=0)
#     return final_audio

# language_code = 0
# group_size = 5
# input_text = "안녕하세요. 당신은 누구신가요? 제 이름은 김영재입니다."
# txt_lists = batch_text_to_sequence(input_text, str(language_code), group_size)
# text_padded, text_lengths = dataCollate(txt_lists)

# sid=None
# spec=None
# speaker_path = "/data/dataset/anam/001_jeongwon/wavs/ad2978265d57a4432c36c86cae5575ef.wav"

# with torch.no_grad():

#     x_tst = text_padded.to(device)
#     x_tst_lengths = text_lengths.to(device)

#     ref_audio, _ = load_wav_to_torch(speaker_path, 22050)
#     ref_audio_norm = ref_audio.unsqueeze(0)
#     spec = spectrogram_torch(ref_audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, center=False).to(device)
    
#     # sid = torch.LongTensor([32]).to(device)
#     audio = net_g.infer(x_tst, x_tst_lengths, y=spec, sid=sid, noise_scale=0.333, noise_scale_w=0.1, length_scale=1)

# final_audio = concat_audio(audio[0]).cpu().float().numpy()
# ipd.display(ipd.Audio(final_audio, rate=hps.data.sampling_rate, normalize=False))

# # from scipy.io.wavfile import write
# # write(f"inference_files/{model_name}_inf.wav", 22050, final_audio)



KeyboardInterrupt: 