In [1]:
import json
import torch
import re
from pathlib import Path
from src.models.hifi_gan.models import Generator, load_model as load_hifi
from src.train_config import TrainParams, load_config
from src.preprocessing.text.cleaners import english_cleaners
import subprocess
from scipy.io.wavfile import write as wav_write
from tqdm.notebook import tqdm


In [2]:
# Load the experiment configuration. Later cells read `device`, `checkpoint_name`,
# `n_mels` and `train_hifi.model_param` from this object.
config = load_config("configs/esd_vctk_15.yml")

In [3]:
# Device used as `map_location` when loading checkpoints and as the target of `.to(...)` below.
device = config.device

In [4]:
# Root directory of this experiment's checkpoints; the "feature" and "hifi"
# subdirectories are read from it in later cells.
checkpoint_path = Path(f"checkpoints/{config.checkpoint_name}")

In [5]:
# Vocoder generator checkpoints: every file named "g_*" (with a dot in its name)
# anywhere under <checkpoint_path>/hifi.
# rglob yields paths in filesystem-dependent order, so the result is sorted: the
# synthesis loop slices this list, and an unsorted list could select a different
# checkpoint from one run (or machine) to the next.
generators = sorted(
    file
    for file in (checkpoint_path / "hifi").rglob("*.*")
    if file.name.startswith("g_")
)

In [6]:
# Arguments for the `mfa g2p` command run by text_to_file: the G2P model archive
# and the file the predictions are written to. parse_g2p reads G2P_OUTPUT_PATH by default.
G2P_MODEL_PATH = "models/en/g2p/english_g2p.zip"
G2P_OUTPUT_PATH = "predictions/to_g2p.txt"

In [7]:
def text_to_file(user_query: str) -> None:
    """Normalize ``user_query`` and run ``mfa g2p`` on the result.

    The text is passed through ``english_cleaners`` and reduced to its runs of
    ASCII letters joined by single spaces. That string is written to a temporary
    ``tmp.txt`` in the current working directory, ``mfa g2p`` is invoked with
    ``G2P_MODEL_PATH`` and ``G2P_OUTPUT_PATH``, and the temporary file is removed.

    Args:
        user_query: Raw input text.

    Returns:
        None. The exit status of ``mfa`` is not inspected.
    """
    # Normalize before touching the filesystem, so a failure inside the cleaner
    # cannot leave an empty tmp.txt behind.
    normalized_content = english_cleaners(user_query)
    normalized_content = " ".join(re.findall("[a-zA-Z]+", normalized_content))

    text_path = Path("tmp.txt")
    try:
        with open(text_path, "w") as fout:
            fout.write(normalized_content)
        # Arguments are passed as a list (no shell); the path is converted to str
        # so the call does not depend on subprocess accepting path-like arguments.
        subprocess.call(
            ["mfa", "g2p", G2P_MODEL_PATH, str(text_path.absolute()), G2P_OUTPUT_PATH]
        )
    finally:
        # Always clean up, including when `mfa` is not installed and
        # subprocess.call raises. The existence check keeps the cleanup from
        # masking the original exception.
        if text_path.exists():
            text_path.unlink()

In [8]:
# Pronunciation overrides: when parse_g2p meets one of these words it uses the
# phoneme string given here instead of the one found in the G2P output file.
default = {
    "he": "HH IY1",
    "she": "SH IY1",
    "we": "W IY1",
    "be": "B IY0",
    "the": "DH AH0",
    "whenever": "W EH0 N EH1 V ER0",
    "year": "AH0 Y IH1 R",
}

def parse_g2p(PHONEMES_TO_IDS, g2p_path: str = G2P_OUTPUT_PATH):
    """Convert a G2P output file into one flat list of phoneme ids.

    Every non-blank line must have the form ``word<TAB>phonemes``, where the
    phonemes are separated by single spaces. Words present in the module-level
    ``default`` mapping use the phoneme string from that mapping instead of the
    one in the file. The resulting sequence starts and ends with
    ``PHONEMES_TO_IDS[""]`` (presumably a padding/boundary symbol; confirm
    against the vocabulary file).

    Args:
        PHONEMES_TO_IDS: Mapping from phoneme symbol to id. Must contain the
            empty-string key.
        g2p_path: Path of the G2P output file to read.

    Returns:
        List of phoneme ids for all words in file order, wrapped in the
        empty-symbol id on both sides.

    Raises:
        KeyError: If a phoneme (or the empty symbol) is missing from
            ``PHONEMES_TO_IDS``.
        ValueError: If a non-blank line contains no tab separator.
    """
    with open(g2p_path, "r") as fin:
        boundary_id = PHONEMES_TO_IDS[""]
        phonemes_ids = [boundary_id]
        for line in fin:
            line = line.rstrip()
            # Blank lines (e.g. a trailing empty line) carry no word; skipping
            # them avoids an unpacking error on the split below.
            if not line:
                continue
            word, word_to_phones = line.split("\t", 1)
            # Hand-written override takes precedence over the predicted phonemes.
            word_to_phones = default.get(word, word_to_phones)
            phonemes_ids.extend(
                PHONEMES_TO_IDS[ph] for ph in word_to_phones.split(" ")
            )
        phonemes_ids.append(boundary_id)
    return phonemes_ids

In [9]:
# Plain-text sentences for this experiment. The six sentence groups in
# `huawei_phones` appear to transcribe these six sentences in the same order
# (TODO confirm).
texts = [
    "Something must be done for them whenever they leave their current place and settle in a new home.",
    "He then really thought himself equal to it.",
    "I had felt that the best of life was over for me!",
    "I hope a hundred pounds a year would make them all comfortable.",
    "Perhaps it would have been as well if he had left it wholly to myself.",
    "A pleasant renewing of old acquaintances, that was all I had thought it, not foreseeing that I was shortly to plunge into this whole situation."
]

In [75]:
# Hand-written phoneme transcriptions (ARPAbet-style symbols with stress digits).
# Layout relied on by later cells:
#   * one word per line, phonemes separated by single spaces;
#   * sentences separated by two blank lines (the text is split on "\n\n\n");
#   * a line holding a single space is NOT blank for to_phones: it splits into two
#     empty-string tokens, each mapped to the "" symbol id (presumably marking a
#     pause at a comma; confirm). Do not strip trailing whitespace in this string.
huawei_phones = """S AH1 M TH IH0 NG
M AH1 S T
B IY1
D AH1 N
F AO1 R
DH EH1 M
W EH0 N EH1 V ER0
DH EY1
L IY1 V
DH EH1 R
K ER1 AH0 N T
P L EY1 S
AH0 N D
S EH1 T AH0 L
IH0 N
AH0
N UW1
HH OW1 M


HH IY1
DH EH1 N
R IH1 L IY0
TH AO1 T
HH IH0 M S EH1 L F
IY1 K W AH0 L
T UW1
IH1 T


AY1
HH AE1 D
F EH1 L T
DH AE1 T
DH AH0
B EH1 S T
AH1 V
L AY1 F
W AA1 Z
OW1 V ER0
F AO1 R
M IY1


AY1
HH OW1 P
AH0
HH AH1 N D R AH0 D
P AW1 N D Z
AH0
Y IH1 R
W UH1 D
M EY1 K
DH EH1 M
AO1 L
K AH1 M F ER0 T AH0 B AH0 L


P ER0 HH AE1 P S
IH1 T
W UH1 D
HH AE1 V
B IH1 N
AE1 Z
W EH1 L
IH1 F
HH IY1
HH AE1 D
L EH1 F T
IH1 T
HH OW1 L IY0
T UW0
M AY2 S EH1 L F


AH0
P L EH1 Z AH0 N T
R IH0 N UW1 IH0 NG
AH1 V
OW1 L D
AH0 K W EY1 N T AH0 N S IH0 Z
 
DH AE1 T
W AA1 Z
AO1 L
AY1
HH AE1 D
TH AO1 T
IH1 T
 
N AA1 T
F AO0 R S IY1 IH0 NG
DH AE1 T
AY1
W AA1 Z
SH AO1 R T L IY0
T UW0
P L AH1 N JH
IH1 N T UW0
DH IH1 S
HH OW1 L
S IH2 CH UW0 EY1 SH AH0 N"""

In [76]:
def to_phones(PHONEMES_TO_IDS, phones):
    """Map a newline-separated phoneme transcription to a list of ids.

    ``phones`` is split into lines; empty lines are ignored and every other line
    is split on single spaces, each resulting token being looked up in
    ``PHONEMES_TO_IDS``. A line made only of spaces is not empty: it produces
    empty-string tokens, which are looked up like any other symbol.

    Args:
        PHONEMES_TO_IDS: Mapping from phoneme symbol to id; must contain "".
        phones: Transcription text, one word per line.

    Returns:
        The ids of all tokens in order, with ``PHONEMES_TO_IDS[""]`` added at
        the start and at the end.
    """
    edge_id = PHONEMES_TO_IDS[""]
    inner_ids = [
        PHONEMES_TO_IDS[symbol]
        for row in phones.split("\n")
        if row
        for symbol in row.split(" ")
    ]
    return [edge_id, *inner_ids, edge_id]


In [77]:
# One id sequence per sentence, in the order the sentences appear in `huawei_phones`.
phonemes_list = []
# Phoneme-symbol -> id table stored in the checkpoint's "feature" directory.
with open(checkpoint_path / "feature"/ "phonemes.json") as f:
    phonemes_to_ids = json.load(f)
# Sentences are separated by two blank lines, i.e. three consecutive newlines.
for hp in huawei_phones.split("\n\n\n"):
    phoneme_ids = to_phones(phonemes_to_ids, hp)
    phonemes_list.append(phoneme_ids)

In [78]:
# Load the serialized feature model object (it is used directly via .eval() and
# .inference() below) and map its tensors onto `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
feature_model = torch.load(checkpoint_path / "feature" / "feature_model.pth", map_location=device)

In [79]:
# Put the feature model into evaluation mode before running inference.
feature_model = feature_model.eval()

In [80]:
def get_tacotron_batch(
    phonemes_ids, reference, speaker_id, device, mels_mean, mels_std
):
    """Assemble the single-utterance input tuple for the feature model.

    The reference mel is standardized as ``(reference - mels_mean) / mels_std``,
    its last two axes are swapped, and it is moved to ``device``. The phoneme ids
    get a leading batch axis of size one.

    Args:
        phonemes_ids: Sequence of phoneme ids for one utterance.
        reference: 3-D reference mel tensor (axes 1 and 2 are swapped here).
        speaker_id: Integer id of the target speaker.
        device: Device the id tensors and the reference are moved to.
        mels_mean: Mean used for standardization; must broadcast against ``reference``.
        mels_std: Standard deviation used for standardization; same requirement.

    Returns:
        Tuple ``(phoneme ids, text length, speaker id, standardized reference)``.
        The text-length tensor is created on the default device and is not moved
        to ``device``; the other three are.
    """
    # Length is kept where LongTensor creates it (not sent to `device`).
    lengths = torch.LongTensor([len(phonemes_ids)])

    standardized_ref = ((reference - mels_mean) / mels_std).permute(0, 2, 1)
    standardized_ref = standardized_ref.to(device)

    ids_batch = torch.LongTensor(phonemes_ids).unsqueeze(0).to(device)
    speaker_batch = torch.LongTensor([speaker_id]).to(device)

    return ids_batch, lengths, speaker_batch, standardized_ref

In [81]:
# Directory searched recursively for reference "*.pkl" files; the name of each
# file's parent folder is used as the speaker key.
reference_pathes = Path("references/")

In [82]:
# Output root. Wavs are written to
# <generated_path>/<generator checkpoint stem>/<speaker>/<reference stem>/<n>.wav.
generated_path = Path(f"generated_hifi/{config.checkpoint_name}")

In [83]:
# Speaker-name -> id table stored in the checkpoint's "feature" directory;
# it is indexed with the reference file's parent folder name.
with open(checkpoint_path / "feature"/ "speakers.json") as f:
    speaker_to_id = json.load(f)

In [84]:
# Mel statistics: used to standardize the reference mel before inference and to
# undo the standardization on the feature model's output.
# NOTE(review): loaded without map_location, so the tensors are restored to the
# device they were saved from; confirm this is fine on machines without that device.
mels_mean = torch.load(checkpoint_path / "feature" / "mels_mean.pth").float()
mels_std = torch.load(checkpoint_path / "feature" / "mels_std.pth").float()

In [85]:
# Synthesis: for each reference mel and each sentence, predict a mel with the
# feature model, vocode it with every selected generator, and write a 16-bit wav.
#
# NOTE(review): the [0:1] slices limit the run to one generator and one reference,
# and phonemes_list[4:] starts at the fifth sentence while the output files are
# numbered from 1. These look like debugging limits; confirm before a full run.

# Build each vocoder once. The checkpoint does not depend on the reference or the
# sentence, so re-reading it and rebuilding the model inside the inner loop was
# repeated, loop-invariant work.
loaded_generators = []
for generator_path in generators[0:1]:
    state_dict = torch.load(generator_path, map_location=device)
    generator = Generator(config=config.train_hifi.model_param, num_mels=config.n_mels).to(device)
    generator.load_state_dict(state_dict["generator"])
    generator.remove_weight_norm()
    generator.eval()
    loaded_generators.append((generator_path, generator))

# sorted(): rglob order is filesystem-dependent, which made the sliced selection arbitrary.
for reference in tqdm(sorted(reference_pathes.rglob("*.pkl"))[0:1]):
    speaker = reference.parent.name
    speaker_id = speaker_to_id[speaker]
    ref_mel = torch.load(reference)
    for i, phonemes in enumerate(phonemes_list[4:]):
        batch = get_tacotron_batch(phonemes, ref_mel, speaker_id, device, mels_mean, mels_std)
        # Both the feature model and the vocoder run without autograd; previously
        # the vocoder pass was outside no_grad and built a graph that was discarded.
        with torch.no_grad():
            mels = feature_model.inference(batch)
            mels = mels.permute(0, 2, 1).squeeze(0)
            # Undo the standardization applied to training mels.
            mels = mels * mels_std.to(device) + mels_mean.to(device)
            x = mels.unsqueeze(0)
            for generator_path, generator in loaded_generators:
                y_g_hat = generator(x)
                audio = y_g_hat.squeeze()
                # Scale to the int16 range and clamp first: a sample of exactly 1.0
                # would become 32768 and wrap around to -32768 in the int16 cast.
                audio = (audio * 32768).clamp(-32768, 32767)
                audio = audio.type(torch.int16).cpu().numpy()
                save_path = generated_path / generator_path.stem / speaker / reference.stem
                save_path.mkdir(exist_ok=True, parents=True)
                # NOTE(review): the sample rate is hard-coded; confirm it matches the vocoder's config.
                wav_write(save_path / f"{i + 1}.wav", 22050, audio)


  0%|          | 0/1 [00:00<?, ?it/s]