In [1]:
import json
import torch
import re
from pathlib import Path
from src.models.hifi_gan.models import Generator, load_model as load_hifi
from src.train_config import TrainParams, load_config
from src.preprocessing.text.cleaners import english_cleaners
import subprocess
from scipy.io.wavfile import write as wav_write
from tqdm.notebook import tqdm


In [2]:
# Load the experiment configuration. Later cells read `device`, `checkpoint_name`,
# `train_hifi.model_param` and `n_mels` from it.
config = load_config("configs/esd_vctk_15.yml")

In [3]:
# Device that checkpoints are mapped to and that inference tensors are moved to.
device = config.device

In [4]:
# Root directory of this experiment's checkpoints; its "feature" and "hifi"
# sub-directories are read by the cells below.
checkpoint_path = Path(f"checkpoints/{config.checkpoint_name}")

In [5]:
# Collect every HiFi-GAN generator checkpoint (file name starting with "g_")
# found anywhere under <checkpoint_path>/hifi.
# `is_file()` keeps directories whose name happens to match the pattern out of
# the list (they would make `torch.load` fail later), and sorting makes the
# processing order reproducible because `rglob` yields entries in an
# arbitrary, filesystem-dependent order.
generators = sorted(
    file
    for file in (checkpoint_path / "hifi").rglob("*.*")
    if file.name.startswith("g_") and file.is_file()
)

In [6]:
# G2P model archive handed to `mfa g2p` by `text_to_file`.
G2P_MODEL_PATH = "models/en/g2p/english_g2p.zip"
# File that `mfa g2p` is told to write to and that `parse_g2p` reads by default.
G2P_OUTPUT_PATH = "predictions/to_g2p.txt"

In [7]:
def text_to_file(user_query: str) -> None:
    """Normalize ``user_query`` and run ``mfa g2p`` on the result.

    The text is passed through ``english_cleaners``, reduced to its runs of
    ASCII letters joined by single spaces, and written to the scratch file
    ``tmp.txt`` in the current working directory. ``mfa g2p`` is then invoked
    with ``G2P_MODEL_PATH``, the scratch file and ``G2P_OUTPUT_PATH``.

    Args:
        user_query: Raw text to convert.

    Raises:
        subprocess.CalledProcessError: If ``mfa g2p`` exits with a non-zero
            status. Failing loudly prevents a stale ``G2P_OUTPUT_PATH`` from
            a previous run being parsed as if it belonged to this query.
    """
    text_path = Path("tmp.txt")
    # Clean the text before touching the filesystem so that a failure in the
    # cleaner does not leave an empty scratch file behind.
    normalized_content = english_cleaners(user_query)
    normalized_content = " ".join(re.findall("[a-zA-Z]+", normalized_content))
    try:
        with open(text_path, "w") as fout:
            fout.write(normalized_content)
        subprocess.run(
            ["mfa", "g2p", G2P_MODEL_PATH, str(text_path.absolute()), G2P_OUTPUT_PATH],
            check=True,
        )
    finally:
        # Always remove the scratch file, even when the subprocess fails.
        if text_path.exists():
            text_path.unlink()

In [8]:
# Hand-written pronunciations that replace the transcription of a word
# whenever `parse_g2p` / `to_phones` meet it (exact, case-sensitive match).
default = {
    "he": "HH IY1",
    "she": "SH IY1",
    "we": "W IY1",
    "be": "B IY0",
    "the": "DH AH0",
    "whenever": "W EH0 N EH1 V ER0",
    "year": "AH0 Y IH1 R",
}

def parse_g2p(PHONEMES_TO_IDS, g2p_path: str = G2P_OUTPUT_PATH):
    """Convert a G2P prediction file into a list of phoneme ids.

    Every non-blank line must look like ``word<TAB>PH1 PH2 ...``. Words that
    are keys of the module-level ``default`` mapping use the pronunciation
    stored there instead of the one in the file. The returned sequence starts
    and ends with ``PHONEMES_TO_IDS[""]``.

    Args:
        PHONEMES_TO_IDS: Mapping from phoneme symbol to integer id; must
            contain the empty string as a key.
        g2p_path: Path of the prediction file to read.

    Returns:
        List of phoneme ids for the whole file.

    Raises:
        KeyError: If a phoneme is missing from ``PHONEMES_TO_IDS``.
        ValueError: If a non-blank line contains no tab separator.
    """
    boundary_id = PHONEMES_TO_IDS[""]
    phonemes_ids = [boundary_id]
    with open(g2p_path, "r") as fin:
        for raw_line in fin:
            line = raw_line.rstrip()
            if not line:
                # Blank or whitespace-only lines (e.g. a trailing newline)
                # carry no entry; unpacking them below would raise ValueError.
                continue
            word, word_to_phones = line.split("\t", 1)
            if word in default:
                word_to_phones = default[word]
            phonemes_ids.extend(
                PHONEMES_TO_IDS[ph] for ph in word_to_phones.split(" ")
            )
    phonemes_ids.append(boundary_id)
    return phonemes_ids

In [9]:
# Evaluation sentences. The transcriptions in `huawei_phones` follow these six
# sentences word by word; `texts` itself is not read by any other cell visible
# in this notebook.
texts = [
    "Something must be done for them whenever they leave their current place and settle in a new home.",
    "He then really thought himself equal to it.",
    "I had felt that the best of life was over for me!",
    "I hope a hundred pounds a year would make them all comfortable.",
    "Perhaps it would have been as well if he had left it wholly to myself.",
    "A pleasant renewing of old acquaintances, that was all I had thought it, not foreseeing that I was shortly to plunge into this whole situation."
]

In [21]:
# Word-level phoneme transcriptions for six sentences, one `word<TAB>phonemes`
# entry per line. Sentences are separated by blank lines: the cell that builds
# `phonemes_list` splits on them, and `to_phones` skips any empty line left over.
# Entries whose word is a key of `default` are overridden inside `to_phones`.
huawei_phones = """something	S AH1 M TH IH0 NG
must	M AH1 S T
be	B IY1
done	D AH1 N
for	F AO1 R
them	DH AH0 M
whenever	W EH0 N EH1 V ER0
they	DH EY1
leave	L IY1 V
their	DH EH2 R
current	K ER1 AH0 N T
place	P L EY1 S
and	AE2 N D
settle	S EH1 T AH0 L
in	IH0 N
a	AH0
new	N UW1
home	HH OW1 M


he	HH IY1
then	DH EH1 N
really	R IY1 L IY0
thought	TH AO1 T
himself	HH IH0 M S EH1 L F
equal	IY1 K W AH0 L
to	T OW0
it	IH0 T


i	AY1
had	HH AE1 D
felt	F EH1 L T
that	DH AE1 T
the	DH AH0
best	B EH1 S T
of	AA1 F
life	L AY1 F
was	W AA1 Z
over	OW2 V ER0
for	F AO1 R
me	M IY1


i	AY1
hope	HH OW1 P
a	AH0
hundred	HH AH1 N D R IH0 D
pounds	P AW1 N D Z
a	AH0
year	Y IH1 R
would	W UH1 D
make	M EY1 K
them	DH AH0 M
all	AH0 L
comfortable	K AH1 M F ER0 T AH0 B AH0 L


perhaps	P ER0 HH AE1 P S
it	IH0 T
would	W UH1 D
have	HH AE1 V
been	B IY0 N
as	AH0 S
well	W EH1 L
if	IH1 F
he	HH IY1
had	HH AE1 D
left	L EH1 F T
it	IH1 T
wholly	HH OW1 L IY0
to	T OW0
myself	M AY2 S EH1 L F


a	AH0
pleasant	P L EH1 Z AH0 N T
renewing	R IH0 N UW1 IH0 NG
of	AA1 F
old	OW1 L D
acquaintances	AH0 K W EY1 N T AH0 N S IH0 Z
that	DH AE1 T
was	W AA1 Z
all	AH0 L
i	AY1
had	HH AE1 D
thought	TH AO1 T
it	IH0 T
not	N AA1 T
foreseeing	F AO0 R S IY1 IH0 NG
that	DH AE1 T
I	AY1
was	W AA1 Z
shortly	SH AO1 R T L IY0
to	T OW0
plunge	P L AH1 N JH
into	IH0 N T AH0
this	DH IH0 S
whole	HH OW1 L
situation	S IH2 CH UW0 EY1 SH AH0 N"""

In [22]:
def to_phones(PHONEMES_TO_IDS, phones):
    """Convert a block of ``word<TAB>phonemes`` lines into phoneme ids.

    Each non-blank line of ``phones`` must look like ``word<TAB>PH1 PH2 ...``.
    Words that are keys of the module-level ``default`` mapping use the
    pronunciation stored there instead of the one given on the line. The
    returned sequence starts and ends with ``PHONEMES_TO_IDS[""]``.

    Args:
        PHONEMES_TO_IDS: Mapping from phoneme symbol to integer id; must
            contain the empty string as a key.
        phones: Newline-separated transcription text.

    Returns:
        List of phoneme ids for the whole block.

    Raises:
        KeyError: If a phoneme is missing from ``PHONEMES_TO_IDS``.
        ValueError: If a non-blank line contains no tab separator.
    """
    boundary_id = PHONEMES_TO_IDS[""]
    phonemes_ids = [boundary_id]
    for raw_line in phones.split("\n"):
        # Strip before testing for emptiness so that whitespace-only lines
        # and stray "\r" from CRLF input are skipped instead of failing the
        # tuple unpacking below.
        line = raw_line.rstrip()
        if not line:
            continue
        word, word_to_phones = line.split("\t", 1)
        if word in default:
            word_to_phones = default[word]
        phonemes_ids.extend(
            PHONEMES_TO_IDS[ph] for ph in word_to_phones.split(" ")
        )
    phonemes_ids.append(boundary_id)
    return phonemes_ids


In [23]:
# Phoneme-to-id mapping stored in the checkpoint's "feature" directory.
with open(checkpoint_path / "feature"/ "phonemes.json") as f:
    phonemes_to_ids = json.load(f)

# One id sequence per sentence. Sentences are separated by one or more blank
# lines, so split on any run of blank lines rather than on exactly "\n\n";
# empty chunks are skipped so they cannot turn into an utterance made only of
# the two boundary ids.
phonemes_list = []
for hp in re.split(r"\n\s*\n", huawei_phones.strip()):
    if not hp:
        continue
    phoneme_ids = to_phones(phonemes_to_ids, hp)
    phonemes_list.append(phoneme_ids)

In [24]:
# Load the serialized feature model, mapping its storages onto `device`.
# NOTE(review): the loaded object is used directly as the model (`.eval()` and
# `.inference()` are called on it below), so the file presumably holds a fully
# pickled module rather than a state dict. Recent PyTorch releases may require
# `weights_only=False` for that; confirm against the installed version.
feature_model = torch.load(checkpoint_path / "feature" / "feature_model.pth", map_location=device)

In [25]:
# Put the feature model into evaluation mode before it is used for inference.
feature_model = feature_model.eval()

In [26]:
def get_tacotron_batch(
    phonemes_ids, reference, speaker_id, device, mels_mean, mels_std
):
    """Assemble the input tuple for a single-utterance inference call.

    Args:
        phonemes_ids: Sequence of phoneme ids for the utterance.
        reference: Three-dimensional reference tensor. It is standardized
            with ``mels_mean`` / ``mels_std`` and its last two axes are swapped.
        speaker_id: Id of the target speaker.
        device: Device the phoneme, speaker and reference tensors are moved to.
        mels_mean: Value subtracted from ``reference``.
        mels_std: Value ``reference`` is divided by after the subtraction.

    Returns:
        Tuple ``(phoneme_ids, text_lengths, speaker_ids, reference)``:
        the phoneme ids with a leading axis of size 1, a one-element tensor
        holding ``len(phonemes_ids)`` (not moved to ``device``), a
        one-element tensor holding ``speaker_id``, and the standardized,
        axis-swapped reference.
    """
    standardized = (reference - mels_mean) / mels_std
    reference_on_device = standardized.permute(0, 2, 1).to(device)

    token_batch = torch.LongTensor(phonemes_ids).unsqueeze(0).to(device)
    token_count = torch.LongTensor([len(phonemes_ids)])
    speaker_batch = torch.LongTensor([speaker_id]).to(device)

    return token_batch, token_count, speaker_batch, reference_on_device

In [27]:
# Directory searched recursively for reference `*.pkl` files; the name of each
# file's parent directory is used as the speaker key.
reference_pathes = Path("references/")

In [28]:
# Output root: wavs are written to
# <generated_path>/<generator file stem>/<speaker>/<reference file stem>/<n>.wav.
generated_path = Path(f"generated_hifi/{config.checkpoint_name}")

In [29]:
# Speaker-name -> id mapping stored in the checkpoint's "feature" directory.
speaker_to_id = json.loads(
    (checkpoint_path / "feature" / "speakers.json").read_text()
)

In [30]:
# Mel statistics used to standardize reference mels and to undo the
# standardization on generated mels; both are cast to float32.
mels_mean, mels_std = (
    torch.load(checkpoint_path / "feature" / stats_file).float()
    for stats_file in ("mels_mean.pth", "mels_std.pth")
)

In [31]:
# Synthesize every sentence in `phonemes_list` for every reference file with
# every HiFi-GAN generator checkpoint, writing one wav per combination.

# Build each vocoder once up front: a checkpoint depends on neither the
# reference nor the sentence, so reloading it inside the loops only repeats
# disk reads and model construction. All generators stay resident on `device`.
loaded_generators = []
for generator_path in generators:
    state_dict = torch.load(generator_path, map_location=device)
    generator = Generator(config=config.train_hifi.model_param, num_mels=config.n_mels).to(device)
    generator.load_state_dict(state_dict["generator"])
    generator.remove_weight_norm()
    generator.eval()
    loaded_generators.append((generator_path, generator))

# The de-normalization statistics are loop-invariant as well.
mels_std_on_device = mels_std.to(device)
mels_mean_on_device = mels_mean.to(device)

for reference in tqdm(list(reference_pathes.rglob("*.pkl"))):
    speaker = reference.parent.name
    speaker_id = speaker_to_id[speaker]
    # NOTE(review): `torch.load` unpickles the file; only point this at
    # reference files from a trusted source.
    ref_mel = torch.load(reference)
    for i, phonemes in enumerate(phonemes_list):
        batch = get_tacotron_batch(phonemes, ref_mel, speaker_id, device, mels_mean, mels_std)
        # No gradients are needed for synthesis; running the vocoder under
        # no_grad as well avoids building an autograd graph per waveform.
        with torch.no_grad():
            mels = feature_model.inference(batch)
            mels = mels.permute(0, 2, 1).squeeze(0)
            mels = mels * mels_std_on_device + mels_mean_on_device
            x = mels.unsqueeze(0)
            for generator_path, generator in loaded_generators:
                y_g_hat = generator(x)
                audio = y_g_hat.squeeze()
                # Scale to the int16 range and clamp first: a sample at +1.0
                # would otherwise become 32768, which int16 cannot represent.
                audio = (audio * 32768).clamp(-32768, 32767)
                audio = audio.type(torch.int16).cpu().numpy()
                save_path = generated_path / generator_path.stem / speaker / reference.stem
                save_path.mkdir(exist_ok=True, parents=True)
                print(save_path)
                # NOTE(review): 22050 Hz is hard-coded; confirm it matches the
                # sampling rate the models were trained with.
                wav_write(save_path / f"{i + 1}.wav", 22050, audio)


  0%|          | 0/1 [00:00<?, ?it/s]

generated_hifi/esd_vctk_15/g_2515000/0014/neutral
generated_hifi/esd_vctk_15/g_2515000/0014/neutral
generated_hifi/esd_vctk_15/g_2515000/0014/neutral
generated_hifi/esd_vctk_15/g_2515000/0014/neutral
generated_hifi/esd_vctk_15/g_2515000/0014/neutral
generated_hifi/esd_vctk_15/g_2515000/0014/neutral
