In [1]:
# Reproducibility setup: pin cuDNN to deterministic behaviour and seed the
# PyTorch and NumPy global random number generators with 0.
import numpy as np
import torch

# Disable cuDNN autotuning and request deterministic kernels.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seed both libraries with the same fixed value.
torch.manual_seed(0)
np.random.seed(0)

In [2]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import librosa
import yaml

# NOTE(review): later cells use `build_model` and `plt` without importing them
# by name; they presumably arrive through the star imports below — confirm, or
# import them explicitly so the notebook survives Restart & Run All.
from models import *
from utils import *



In [3]:
# Mel-spectrogram front end used by preprocess(): 80 mel bands, 2048-point FFT,
# 1200-sample window, 300-sample hop.
# NOTE(review): sample_rate is not passed, so torchaudio's default applies even
# though preprocess() loads audio at 24000 Hz — confirm this matches how the
# checkpoint was trained.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
# Offset and scale applied to the log-mel values in preprocess().
mean, std = -4, 4

def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: 1-D integer tensor; one length per batch item.

    Returns:
        Bool tensor of shape (len(lengths), lengths.max()). An entry is True
        where the position index is at or beyond that item's length (padding)
        and False inside the valid part of the sequence.
    """
    # Position indices 0..max-1, cast to the same tensor type as `lengths`.
    positions = torch.arange(lengths.max()).type_as(lengths)
    positions = positions.unsqueeze(0).expand(lengths.shape[0], -1)
    # Position p is padding when p + 1 exceeds the item's length.
    return (positions + 1) > lengths.unsqueeze(1)

def preprocess(path):
    """Load an audio file and convert it to a normalised log-mel tensor.

    Steps: load the file at 24000 Hz, trim it with librosa.effects.trim
    (top_db=30), pad 5000 zero samples on each side, run the module-level
    `to_mel` transform, take log(1e-5 + mel) and normalise with the
    module-level `mean` and `std`.

    Args:
        path: path of the audio file to load.

    Returns:
        Float tensor holding the normalised log-mel values, with one extra
        leading dimension of size 1 added by unsqueeze(0).
    """
    # librosa.load already resamples to the requested rate, so the rate it
    # returns is always 24000 and no separate resample step is needed.
    wave, _ = librosa.load(path, sr=24000)
    wave, _ = librosa.effects.trim(wave, top_db=30)

    # Pad both ends with 5000 zero samples.
    silence = np.zeros([5000])
    wave = np.concatenate([silence, wave, silence], axis=0)

    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    # The 1e-5 floor keeps the log finite for empty mel bins.
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor


In [None]:
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Bare expression so the notebook displays the chosen device.
device

In [None]:
# Read the training configuration. The context manager closes the file handle,
# which the previous bare open() call left to the garbage collector.
with open("logs/aihub_muilti_lingual_en_jp_ko2/config_en_jp_ko2.yml") as config_file:
    config = yaml.safe_load(config_file)
# config = yaml.safe_load(open("logs/LibriTTS/config.yml"))

# Build the network from the "model_params" section, switch it to evaluation
# mode and move it to the selected device.
model = build_model(model_params=config["model_params"])
model.eval()
model.to(device)

In [None]:
# Alternative checkpoints kept for quick switching; only the uncommented
# assignment below takes effect.
# checkpoint_path = "logs/aihub_muilti_lingual_en_jp_ko2/epoch_00080.pth"
# checkpoint_path = "logs/LibriTTS/epoch_00080.pth"
# checkpoint_path = "logs/aihub_female/epoch_1st_00000.pth"
# checkpoint_path = "logs/aihub_va_all/epoch_1st_00000.pth"
# checkpoint_path = "logs/jp_valid_word/epoch_1st_00003.pth"
checkpoint_path = "logs/eng/epoch_1st_00000.pth"


# Deserialize the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(checkpoint_path, map_location="cpu")
# model.load_state_dict(state_dict["model"])
# The weights are read from the ["net"]["text_aligner"] entry of this checkpoint.
# NOTE(review): the config used to build `model` comes from a different logs/
# directory than this checkpoint — confirm the architectures match.
model.load_state_dict(state_dict["net"]["text_aligner"])


In [57]:
#--- multi lingual vocabs
# Symbol inventory: pad, punctuation, Latin letters, IPA, Japanese phoneme
# tokens and Hangul jamo, concatenated in that order.
_pad = "$"
_punctuation = ';:,.!?¡¿—…\'"«»“”()-=^&*~ '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# Japanese tokens are whole strings (mostly multi-character), not single characters.
_letters_jp = ['by', 'ch', 'cl', 'd', 'dy', 'gy', 'hy', 'ky', 'my', 'ny', 'pau', 'py', 'ry', 'sh', 'ts', 'ty']

# Hangul jamo code-point ranges: leads, vowels, tails.
_letter_ko_JAMO_LEADS = "".join(map(chr, range(0x1100, 0x1113)))
_letter_ko_JAMO_VOWELS = "".join(map(chr, range(0x1161, 0x1176)))
_letter_ko_JAMO_TAILS = "".join(map(chr, range(0x11A8, 0x11C3)))
_letter_ko_CHARS = "".join((_letter_ko_JAMO_LEADS, _letter_ko_JAMO_VOWELS, _letter_ko_JAMO_TAILS))

symbols = [_pad, *_punctuation, *_letters, *_letters_ipa, *_letters_jp, *_letter_ko_CHARS]
# ---

# Symbol -> index lookup. When a symbol occurs more than once in `symbols`
# (for example 'd'), the index of its last occurrence is the one kept.
dicts = {symbol: index for index, symbol in enumerate(symbols)}

class TextCleaner:
    """Convert already-phonemized text into a list of symbol indices.

    The lookup table is the module-level `dicts` mapping captured when the
    instance is created.

    NOTE(review): this notebook defines `TextCleaner` in three separate cells;
    the definition from whichever cell ran last is the one in effect.
    """

    def __init__(self, dummy=None):
        # `dummy` is accepted and ignored.
        self.word_index_dictionary = dicts

    def __call__(self, text, cleaned=False, lang='en'):
        """Look up every element of `text` in the symbol table.

        Args:
            text: an iterable of symbols — a string (iterated per character)
                or a list of token strings.
            cleaned: must be truthy; phonemization is not implemented here.
            lang: accepted but not used.

        Returns:
            List of indices for the symbols that were found. Unknown symbols
            are reported on stdout and skipped.

        Raises:
            ValueError: if `cleaned` is falsy.
        """
        if not cleaned:
            raise ValueError('Not implemented')

        table = self.word_index_dictionary
        ids = []
        for symbol in text:
            try:
                idx = table[symbol]
            except KeyError:
                print('KeyError:', symbol, text)
            else:
                ids.append(idx)
        return ids

# Shared cleaner instance for the multilingual vocabulary. The misspelled name
# is the one the inference cell calls, so it must stay as is.
textclenaer = TextCleaner()

In [20]:
# Korean vocab
# Symbol inventory: pad / bos / eos markers, Hangul jamo, punctuation, space
# and three silence tokens, concatenated in that order.

PAD = '_'
BOS = '<bos>'
EOS = '<eos>'
PUNC = '!?\'\"().,-=:;^&*~'
SPACE = ' '
_SILENCES = ['sp', 'spn', 'sil']

# Hangul jamo code-point ranges: leads, vowels, tails.
JAMO_LEADS = "".join(map(chr, range(0x1100, 0x1113)))
JAMO_VOWELS = "".join(map(chr, range(0x1161, 0x1176)))
JAMO_TAILS = "".join(map(chr, range(0x11A8, 0x11C3)))

VALID_CHARS = "".join((JAMO_LEADS, JAMO_VOWELS, JAMO_TAILS, PUNC, SPACE))
symbols = [PAD, BOS, EOS, *VALID_CHARS, *_SILENCES]

#---
# Symbol -> index lookup; a repeated symbol would keep its last index.
dicts = {symbol: index for index, symbol in enumerate(symbols)}

class TextCleaner:
    """Convert text into a list of symbol indices (Korean vocabulary cell).

    The lookup table is the module-level `dicts` mapping captured when the
    instance is created.

    NOTE(review): this notebook defines `TextCleaner` in three separate cells;
    the definition from whichever cell ran last is the one in effect.
    """

    def __init__(self, dummy=None):
        # `dummy` is accepted and ignored.
        self.word_index_dictionary = dicts

    def __call__(self, text, cleaned=True):
        """Look up every element of `text` in the symbol table.

        Args:
            text: an iterable of symbols (a string is iterated per character).
            cleaned: accepted but not used.

        Returns:
            List of indices for the symbols that were found. Unknown symbols
            are reported on stdout together with the full text, then skipped.
        """
        table = self.word_index_dictionary
        ids = []
        for symbol in text:
            try:
                idx = table[symbol]
            except KeyError:
                print(f"Unknown character: {symbol}")
                print(text)
            else:
                ids.append(idx)
        return ids

# Shared cleaner instance for the Korean vocabulary. The misspelled name is the
# one the inference cell calls, so it must stay as is.
textclenaer = TextCleaner()

In [13]:
# Original
# Symbol inventory: pad, punctuation, Latin letters and IPA, concatenated in
# that order.

_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# Export all symbols:
symbols = [_pad, *_punctuation, *_letters, *_letters_ipa]

# Symbol -> index lookup. A symbol that occurs more than once in `symbols`
# keeps the index of its last occurrence.
dicts = {symbol: index for index, symbol in enumerate(symbols)}

class TextCleaner:
    """Convert text into a list of symbol indices (original vocabulary cell).

    The lookup table is the module-level `dicts` mapping captured when the
    instance is created.

    NOTE(review): this notebook defines `TextCleaner` in three separate cells;
    the definition from whichever cell ran last is the one in effect.
    """

    def __init__(self, dummy=None):
        # `dummy` is accepted and ignored; it exists so this class can be
        # constructed the same way as the other TextCleaner definitions.
        self.word_index_dictionary = dicts

    def __call__(self, text, cleaned=True):
        """Look up every element of `text` in the symbol table.

        Args:
            text: an iterable of symbols (a string is iterated per character).
            cleaned: accepted but not used.

        Returns:
            List of indices for the symbols that were found. Unknown symbols
            are reported on stdout together with the full text, then skipped.
        """
        indexes = []
        for char in text:
            try:
                indexes.append(self.word_index_dictionary[char])
            except KeyError:
                # Name the offending symbol; printing only the text made it
                # hard to see which symbol was missing from the vocabulary.
                print(f"Unknown character: {char}")
                print(text)
        return indexes

# Shared cleaner instance for the original vocabulary. The misspelled name is
# the one the inference cell calls, so it must stay as is.
textclenaer = TextCleaner()

In [8]:
from PIL import Image

def get_image(arrs):
    """Stack arrays vertically into one 8-bit grayscale PIL image.

    Each array is min-max scaled on its own to the range 0..255, converted to
    a PIL image and pasted below the previous one, aligned to the left edge.

    Args:
        arrs: iterable of numpy arrays; expected to be 2-D, since shape[0] is
            used as the height and shape[1] as the width of each image.

    Returns:
        A mode 'L' PIL image whose width is the widest input and whose height
        is the sum of all input heights. Areas not covered by a narrower
        input keep the canvas default.
    """
    pil_images = []
    height = 0
    width = 0
    for arr in arrs:
        lowest = arr.min()
        value_range = arr.max() - lowest
        if value_range > 0:
            scaled = (arr - lowest) / value_range
        else:
            # A constant array has zero range; dividing would give 0/0 = NaN,
            # which has no defined uint8 value. Render it as all zeros.
            scaled = np.zeros(arr.shape)
        uint_arr = (scaled * 255).astype(np.uint8)
        pil_images.append(Image.fromarray(uint_arr))
        height += uint_arr.shape[0]
        width = max(width, uint_arr.shape[1])

    palette = Image.new('L', (width, height))
    top = 0
    for pil_image in pil_images:
        palette.paste(pil_image, (0, top))
        # PIL's size is (width, height); advance by the pasted image's height.
        top += pil_image.size[1]

    return palette

In [None]:
import IPython.display as ipd

# Mixed-language sample set. Each entry is "relative_wav_path|text|id"; the
# loop at the end of this cell splits entries on '|'.
# NOTE: `lines` is reassigned several times further down this cell before it is
# read, so this assignment has no effect as the cell is written.
lines = [
    "voice_actor/Training_24k/original/Sportscaster/S-YZ-A-022-0240.wav|점수가 업치락뛰치라카면서 오늘 경기 보시는 분들 정말 재미이쓸 껃 가타요.|22-S-YZ",
    "multi_lingual/jp/TS_jp/1-jp-100020_1054_1-M-41-100020_1054_1_32.wav|<jp>k o n o y o o n a k i n o o g a k e e s e e s a r e r u m e k a n i z u m u w a pau m a z u w a sh a k a i n o k o p i i o ts u u j i t e k e e s e e s a r e r u|100348",
    "multi_lingual/TS_en/5-en-100019_944_1-F-25-100019_944_1_10.wav|fɔːɹ ɛɡzˈæmpəl, ɪn kɚɹˈiːən səsˈaɪəɾi, wˌɛɹ ðɪ ˌaɪdɪəlˈɑːdʒɪkəl ˌɑːpəzˈɪʃən bɪtwˌiːn kənsˈɜːvətˌɪvz ænd pɹəɡɹˈɛsɪvz ɪz ɪntˈɛns, pˈʌblɪk ˈɪntɹəst ɪz ɪnˈɛvɪɾəbli vˈɛɹi ˈæbstɹækt.|100204",
    "multi_lingual/TS_en/5-en-100019_938_1-M-27-100019_938_1_13.wav|ɪt ɡˈɪvz ðə pɚsˈɛpʃən ðæt fɹˈiːdəm ʌv ɛkspɹˈɛʃən dʌznˌɑːt ɛndʒˈɔɪ ɐ suːpˈiəɹɪɚ stˈæɾəs kəmpˈɛɹd tʊ ˈʌðɚ fˌʌndəmˈɛntəl ɹˈaɪts.|100232",
    "voice_actor/Training_24k/original/Dialogue/D-S3-A-009-0108.wav|가슴과 머리가 터질 드시 아파써, 우르미 터저 나와써.|9-D-S3",
# "Demo/reference_audio/sad_english.mp3|həlˈoʊ , θˈæŋk juː fɔːɹ vˈɪzɪɾɪŋ ðɪ ɛksɪbˈɪʃən ʌv ɡlˈoʊbəl ˈoʊpən ˌɪnəvˈeɪʃən æktˈɪvɪɾiz .|111"
]

# Korean sample set; entries are "relative_wav_path|text|id".
# NOTE: `lines` is reassigned again further down this cell before it is read,
# so this assignment has no effect as the cell is written.
lines = [
    "voice_actor/Training_24k/original/Kindly/K-NZ-G-047-0410.wav|오전 열 시 사십 분쯔메는 전남 여천군 율촌면 우루과이라운드 빨리 조치를 취하도록 하겓씀니다.|47",
    "voice_actor/Training_24k/original/Narration/N-NX-G-047-0424.wav|당시 에이씨는 비씨의 멱싸를 소느로 자븐 후 발로 허벅찌를 오 회가량 거더찬 거스로 조사됃따.|47",
    "voice_actor/Training_24k/original/Fairy/F-NY-C-048-0197.wav|곧 아이드리 원시사회로부터 하나 둘씩 모여드럳씀니다.|48",
    "voice_actor/Training_24k/original/Fairy/F-NY-C-048-0133.wav|노리터에서 책까방을 멘 친구드리 하나둘 모엳씀니다.|48",
    "voice_actor/Training_24k/original/Animation/A-NX-C-048-0438.wav|현시를 어느 정도 바다드리지 모타면 방어기제 발똥 상황을 맏께 돼요.|48",
]


# Japanese sample set; entries are "relative_wav_path|text|id". The text field
# starts with a "<jp>" tag followed by space-separated tokens.
# NOTE: `lines` is reassigned again further down this cell before it is read,
# so this assignment has no effect as the cell is written.
lines = [
    "multi_lingual/jp/TS_jp/5-jp-100020_3329_5-F-51-100020_3329_5_24.wav|<jp>g e N j i ts u d e m o b o o ry o k u g a h a cl s e e s u r u b a a i g a a r u|100166",
    "multi_lingual/jp/TS_jp/5-jp-100020_3329_6-F-51-100020_3329_6_43.wav|<jp>gy u N t a a g a t e e j i sh I t a b o o ry o k u n o sh u r u i w a s u b e t e m e n i m i e r u b u ts u r i t e k i b o o ry o k u d e a r u|100166",
    "multi_lingual/jp/TS_jp/5-jp-100020_3359_1-F-28-100020_3359_1_26.wav|<jp>s a i k i N pau k a N k o k u d e w a k o o ky o o b u m o N o ch u u sh i N n i k a cl p a ts u n i d o o ny u u sh I t e o r i pau d a i g a k u d e m o ny u u g a k U s a t e e n i i ch i b u k a ts u y o o sh I t e i r u|100447",
    "multi_lingual/jp/TS_jp/5-jp-100020_3360_10-F-34-100020_3360_10_49.wav|<jp>k o n o s e e s a k U k e cl t e e n o r o N r i w a sh i N p u r u d a|100833",
    "multi_lingual/jp/TS_jp/5-jp-100020_3360_3-F-31-100020_3360_3_45.wav|<jp>k a k U sh i N w a y o o s e k i r i ts u o t a k a m e r u k a i h a ts u y a sh i N t o sh i o ts u u j i t a ky o d a i n a j u u t a k u d a N ch i n o f U ky u u d e w a n a k u pau j u u t a k U s e e s a k u n o h o o k o o j i t a i n o t e N k a N d a|100782",
    "multi_lingual/jp/TS_jp/5-jp-100020_3360_6-F-32-100020_3360_6_6.wav|<jp>d a r e g a m i t e m o k I k e N n i m i e r u r o o ky u u sh I t a j i d o o k o o e N n i n o m i n o r i g a t e k i y o o s a r e t a n o d e w a n a i|100786",
    "multi_lingual/jp/TS_jp/5-jp-100020_3399_40-F-45-100020_3399_40_22.wav|<jp>sh i i e s U k a N r i sh a w a pau sh i i e s u sh a i N g a k o ky a k u n i y a r a r e r u n o o z a N n e N g a r u d a k e d e pau k o n o h o o h o o o k o N p o N t e k i n i k a i k e ts U s u r u h o o h o o o sh i cl t e i n a i y o o d e s U|100843",
    "multi_lingual/jp/TS_jp/5-jp-100020_3399_54-F-25-100020_3399_54_21.wav|<jp>sh I k a sh i pau ch u u i s u r u k o t o w a pau k o ky a k u n o k a N j o o o k i z u ts U k e t a k o t o n i t a i sh I t e n o m i sh a z a i sh i pau k o ky a k u n o a y a m a ch i m a d e s u b e t e j i b u N n o s e i t o m i t o m e t e w a i k e m a s e N|100824",
    "multi_lingual/jp/TS_jp/5-jp-100020_3410_9-F-27-100020_3410_9_18.wav|<jp>m a t a pau s o sh I k i sh i m i N k o o d o o n o k u r a i s o k u m e N n i m o ch u u m o k U sh i n a k e r e b a n a r a n a i|100821",
    "multi_lingual/jp/TS_jp/5-jp-100020_868_1-F-50-100020_868_1_27.wav|<jp>t e k u n o r o j i i g a t a N n i n i N g e N n o k a N k a k u pau s e e sh i N pau j i N s e e n i ch o k U s e ts u e e ky o o o o y o b o s U t o m i r u n o w a pau g i j u ts U k e cl t e e r o N t e k i n a sh I k o o n o t e N k e e n i s u g i n a i t o i u k o t o d a|100316",
]

# English sample set; entries are "relative_wav_path|text|id" with the text
# given as an IPA transcription.
# NOTE: `lines` is reassigned again further down this cell before it is read,
# so this assignment has no effect as the cell is written.
lines = [
    "multi_lingual/TS_en/3-en-100019_539_1-F-23-100019_539_1_35.wav|ˌɛldʒˈiː, wˌɪtʃ wˈʌn ðə tʃˈæmpiənʃˌɪp fɚðə sˈɛkənd tˈaɪm sˈɪns ɪts faʊndˈeɪʃən ɪn nˈaɪntiːnhˈʌndɹəd nˈaɪntifˈoːɹ, tʃˈeɪndʒd ðə nˈeɪm ʌv ɪts pˈɛɹənt kˈʌmpəni wɪððə vˈɪktɚɹi ænd səksˈiːdᵻd ɪn mˈænɪdʒmənt.|100167",
    "multi_lingual/TS_en/3-en-100019_588_1-F-29-100019_588_1_9.wav|ðɛɹˌɑːɹ mˈoːɹ wˈɜːks ʌv lˈɪɾɚɹətʃɚɹ ɐbˌaʊt ðə fˈækt ðæt pˈiːpəl pɚsˈuː bjˈuːɾi ænd lˈʌv bjˈuːɾɪfəl θˈɪŋz ðɐn juː mˌaɪt θˈɪŋk.|100210",
    "multi_lingual/TS_en/3-en-100019_787_1-F-35-100019_787_1_30.wav|haʊˈɛvɚ, ðɛɹ ɪz ɐ pɹˈɑːbləm ðæt dəmˈɛstɪk vˌiːˈɑːɹ wɛbtˈuːn kˈɑːntɛnt pɹədˈuːsɚz ɑːɹ nˌɑːt ˈeɪbəl tʊ ˈɛntɚ smˈuːðli æz ðə hˈɛd ˈɑːfɪs dɚɹˈɛktli mˈænɪdʒᵻz ænd kəntɹˈoʊlz ˈɑːkjʊləs ɡˌoʊz dˌɪstɹɪbjˈuːʃən lˈaɪn ʌntˈɪl nˈaʊ.|100451",
    "multi_lingual/TS_en/3-en-100019_966_1-F-29-100019_966_1_11.wav|ðˈɛn, juː wɪl pˈæs baɪ kˈɪm mˈɪnbˈuː ɑːbzˈɜːvətˌoːɹi ænd ˈiːbəɡˈuː wˈɜːkʃɑːp, wˌɛɹ juː kæn sˈiː bjˈuːsən pˈoːɹt.|100210",
    "multi_lingual/TS_en/5-en-100019_131_1-F-31-100019_131_1_4.wav|ðɪs ɪz bɪkˈʌz sˈɜːtən kˈʌltʃɚɹəl nˈoʊʃənz kæn ɡˈɪv lˈɛksɪkəl kˌæɹɪktɚɹˈɪstɪks tʊ ɐn ˌaɪkənəɡɹˈæfɪk tˈɛkst.|100444",
    "multi_lingual/TS_en/5-en-100019_1409_3-F-26-100019_1409_3_30.wav|bɪhˌaɪnd ðə ɹˈɛɾɚɹˌɪk nˈɛɡɹi kˈɔːlz ðɪ ˈɛmpaɪɚɹ ɪz ðə juːnˈaɪɾᵻd stˈeɪts.|100205",
    "multi_lingual/TS_en/5-en-100019_1410_17-F-25-100019_1410_17_21.wav|ɪnðə nˈaɪntiːnhˈʌndɹəd nˈaɪnti z, tʃˈaɪnəz ˌiːkənˈɑːmɪk dɪvˈɛləpmənt bɪɡˈæn təbi ɐksˈɛlɚɹˌeɪɾᵻd.|100202",
    "multi_lingual/TS_en/5-en-100019_1412_13-F-30-100019_1412_13_8.wav|ɹɪsˈɜːtʃɚz ʃˌʊd biː ɪspˈɛʃəli kˈɛɹfəl nˌɑːt tə ɹˈʌʃ tə kənklˈuːʒənz ʌntˈɪl ðeɪ lˈiːv ðɛɹ ɹɪsˈɜːtʃ sˈaɪt.|100208",
]

# lines = [
#     "voice_actor/Training_24k/original/Kindly/K-NZ-G-047-0410.wav|오전 열 시 사십 분쯔메는 전남 여천군 율촌면 우루과이라운드 빨리 조치를 취하도록 하겓씀니다.|47",
#     "voice_actor/Training_24k/original/Narration/N-NX-G-047-0424.wav|당시 에이씨는 비씨의 멱싸를 소느로 자븐 후 발로 허벅찌를 오 회가량 거더찬 거스로 조사됃따.|47",
#     "voice_actor/Training_24k/original/Fairy/F-NY-C-048-0197.wav|곧 아이드리 원시사회로부터 하나 둘씩 모여드럳씀니다.|48",
#     "voice_actor/Training_24k/original/Fairy/F-NY-C-048-0133.wav|노리터에서 책까방을 멘 친구드리 하나둘 모엳씀니다.|48",
#     "voice_actor/Training_24k/original/Animation/A-NX-C-048-0438.wav|현시를 어느 정도 바다드리지 모타면 방어기제 발똥 상황을 맏께 돼요.|48",
# ]

# Japanese sample set; entries are "relative_wav_path|text|id". The text field
# starts with a "<jp>" tag followed by space-separated tokens.
# This is the last assignment to `lines` in the cell, so these are the entries
# the loop below actually processes.
lines = [
    "multi_lingual/jp/TS_jp/5-jp-100020_3329_3-F-24-100020_3329_3_44.wav|<jp>b o o ry o k u t e k i n a d e N sh i g e e m u g a m o ts U ky u u i N ry o k u o r i k a i s u r u m o k U t e k i d e o k o n a w a r e t a k e N ky u u k e cl k a o m i t e m i y o o|100823",
    "multi_lingual/jp/TS_jp/5-jp-100020_3329_4-F-24-100020_3329_4_2.wav|<jp>i k a n a r u b o o ry o k u m o b i k a s a r e t a r i s e e t o o k a s a r e r u k o t o g a a cl t e w a n a r a n a i|100823",
    "multi_lingual/jp/TS_jp/5-jp-100020_3329_4-F-24-100020_3329_4_32.wav|<jp>i cl p a N t e k i n a b o o ry o k u d e w a pau k a sh i t e k i d e pau k a N t a N n i k a N s a ts u d e k i r u sh i N t a i t e k I sh o o g a i o m a s u m a s U ky o o ch o o s u r u|100823",
    "multi_lingual/jp/TS_jp/5-jp-100020_3329_4-F-24-100020_3329_4_38.wav|<jp>g e e m u r i y o o ch u u n i b o o ry o k U s e e n i s a r a s a r e r u t o pau j i cl s a i n i k o o g e k I s e e o y o o n i N sh i m i t o m e r u k e e k o o g a a r i pau sh i d a i n i b o o ry o k u n i d o N k a N n i n a r i pau b o o ry o k u o t o o z e N sh I s u r u k a n o o s e e g a a r u|100823",
    "multi_lingual/jp/TS_jp/5-jp-100020_3329_4-F-24-100020_3329_4_40.wav|<jp>s a r a n i k o d o m o o t a i sh o o t o s u r u m a N g a d e m o b o o ry o k u t e k i n a n a i y o o t o sh i i N g a m i ts U k e r a r e r u|100823",
]


# Directory that the relative wav paths in `lines` are resolved against.
root_dir = '/home/jovyan/datasets/AIHUB/'
# One attention image per processed entry, in the same order as `lines`.
images = []
for line in lines:
    # Entry layout is "relative_wav_path|text|id"; the id field is not used here.
    path, text, _ = line.split('|')
    print(text)
    path = os.path.join(root_dir, path)
    if text[:4] == '<jp>':
        # Entries tagged <jp> hold space-separated tokens, some of them longer
        # than one character, so hand the cleaner a list of tokens rather than
        # a string that would be iterated character by character.
        text = text[4:].split()
    text_input = textclenaer(text, cleaned=True)
    # Surround the sequence with index 0 at both ends.
    text_input.insert(0, 0)
    text_input.append(0)
    text_input = torch.LongTensor(text_input).unsqueeze(0).to(device)
    # Record the real token count. This was previously left at zero, which made
    # length_to_mask() return an empty mask for the text.
    text_input_length = torch.LongTensor([text_input.shape[-1]]).to(device)

    mel_input = preprocess(path).to(device)
    # Drop the final mel frame.
    mel_input = mel_input[:, :, :mel_input.size(-1) -1]
    mel_input_length = torch.LongTensor([mel_input.shape[-1]]).to(device)

    with torch.no_grad():
        # Express the mel length in units of 2 ** model.n_down frames and cut
        # the mel to a whole multiple of that factor.
        mel_input_length = mel_input_length // (2 ** model.n_down)
        mel_input = mel_input[..., :mel_input_length * (2 ** model.n_down)]
        mel_mask = length_to_mask(mel_input_length).to(device)

        # NOTE(review): text_mask is built but not passed to the model call below.
        text_mask = length_to_mask(text_input_length).to(device)
        ppgs, s2s_pred, s2s_attn = model(
            mel_input, src_key_padding_mask=mel_mask, text_input=text_input)

    # Keep a rendered copy of this attention map. The list was previously
    # never filled although the following cell indexes into it.
    images.append(get_image([s2s_attn[0].cpu().numpy()]))

    # plot one s2s attention
    plt.figure(figsize=(10, 5))
    plt.imshow(s2s_attn.squeeze().cpu().numpy(), aspect='auto', origin='lower')
    plt.colorbar()
    plt.show()

    # Play the untrimmed reference recording next to its attention plot.
    ref, sr = librosa.load(path, sr=24000)
    ipd.display(ipd.Audio(ref, rate=sr))



In [None]:
# Show the stored attention image at index 2. Guard the lookup so the cell
# reports a short message instead of raising IndexError when fewer images
# were collected.
image_index = 2
if image_index < len(images):
    plt.imshow(images[image_index])
    plt.show()
else:
    print(f"images holds {len(images)} item(s); nothing to show at index {image_index}")

In [None]:
# model reload
config = yaml.safe_load(open("logs/aihub_muilti_lingual_en_jp_ko2/config_en_jp_ko2.yml"))
# config = yaml.safe_load(open("logs/LibriTTS/config.yml"))

model = build_model(model_params=config["model_params"])
# model.eval()
model.to(device)
checkpoint_path = "logs/aihub_muilti_lingual_en_jp_ko2/epoch_00080.pth"
# checkpoint_path = "logs/LibriTTS/epoch_00080.pth"
# checkpoint_path = "logs/aihub_female/epoch_1st_00000.pth"
# checkpoint_path = "logs/aihub_va_all/epoch_1st_00000.pth"

# checkpoint_path = "logs/jp_valid_word/epoch_1st_00003.pth"
# checkpoint_path = "logs/jp/epoch_1st_00012.pth"
# checkpoint_path = "logs/eng/epoch_2nd_00007.pth"
# checkpoint_path = "logs/eng_4s/epoch_2nd_00004.pth"


state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict["model"])
# model.load_state_dict(state_dict["net"]["text_aligner"])

# new_state_dict = {}
# for k, v in state_dict["net"]["text_aligner"].items():
#     k = k.replace("module.", "")
#     new_state_dict[k] = v
# model.load_state_dict(new_state_dict)
