In [1]:
import json
import torch
import re
from pathlib import Path
from src.models.hifi_gan.models import Generator, load_model as load_hifi
from src.train_config import TrainParams, load_config
from src.preprocessing.text.cleaners import english_cleaners
import subprocess
from scipy.io.wavfile import write as wav_write
from tqdm.notebook import tqdm


In [2]:
# Experiment configuration; its checkpoint name, device, n_mels and HiFi-GAN
# model params are all consumed by the cells below.
CONFIG_PATH = "configs/esd_vctk_15.yml"
config = load_config(CONFIG_PATH)

In [6]:
# Device used as map_location when loading models and as the target of .to(device)
# for the tensors fed to / produced by them.
device = config.device

In [4]:
# Root directory of this experiment's artifacts; the "feature" and "hifi"
# sub-directories underneath it are read by the cells below.
checkpoint_path = Path(f"checkpoints/{config.checkpoint_name}")

In [7]:
# HiFi-GAN generator checkpoints: regular files named "g_<...>.<ext>" anywhere under
# <checkpoint_path>/hifi. "g_*.*" matches exactly the names accepted by the previous
# `rglob("*.*")` + `startswith("g_")` filter. Sorting makes the processing order
# reproducible, and is_file() keeps dotted directories out of torch.load.
# NOTE(review): the upstream HiFi-GAN training script saves generators as
# extension-less "g_<steps>" files, which would NOT match "*.*" — confirm this
# project's checkpoint naming.
generators = sorted(
    path for path in (checkpoint_path / "hifi").rglob("g_*.*") if path.is_file()
)
# Fail fast: with no checkpoints the synthesis loop below would run the feature
# model and silently write nothing.
if not generators:
    raise FileNotFoundError(
        f"no generator checkpoints matching 'g_*.*' found under {checkpoint_path / 'hifi'}"
    )

In [8]:
# MFA grapheme-to-phoneme model archive passed to `mfa g2p` by text_to_file.
G2P_MODEL_PATH = "models/en/g2p/english_g2p.zip"
# File `mfa g2p` writes its predictions to; parse_g2p reads it back by default.
G2P_OUTPUT_PATH = "predictions/to_g2p.txt"

In [9]:
def text_to_file(
    user_query: str,
    g2p_model_path: str = G2P_MODEL_PATH,
    output_path: str = G2P_OUTPUT_PATH,
) -> None:
    """Normalize ``user_query`` and run ``mfa g2p`` on it.

    The text is passed through ``english_cleaners`` and then reduced to runs of
    ASCII letters joined by single spaces (everything else is discarded). The
    result is written to a scratch file ``tmp.txt`` in the current working
    directory, handed to ``mfa g2p``, and the scratch file is always removed.

    Args:
        user_query: raw input sentence.
        g2p_model_path: MFA G2P model archive passed to ``mfa g2p``.
        output_path: file ``mfa g2p`` writes its predictions to.

    Raises:
        subprocess.CalledProcessError: if ``mfa`` exits non-zero. Previously the
            exit status was ignored, so a stale ``output_path`` left over from an
            earlier sentence could be parsed silently.
        FileNotFoundError: if the ``mfa`` executable is not on PATH.
    """
    text_path = Path("tmp.txt")
    normalized_content = english_cleaners(user_query)
    normalized_content = " ".join(re.findall("[a-zA-Z]+", normalized_content))
    try:
        with open(text_path, "w") as fout:
            fout.write(normalized_content)
        subprocess.run(
            ["mfa", "g2p", g2p_model_path, text_path.absolute(), output_path],
            check=True,
        )
    finally:
        # Remove the scratch file even when mfa fails or is missing.
        text_path.unlink(missing_ok=True)

In [12]:
# Hand-picked pronunciations that override the G2P model's prediction for these words.
default = {
    "he": "HH IY1",
    "she": "SH IY1",
    "we": "W IY1",
    "be": "B IY0",
    "the": "DH AH0",
    "whenever": "W EH0 N EH1 V ER0",
}

def parse_g2p(PHONEMES_TO_IDS, g2p_path: str = G2P_OUTPUT_PATH):
    """Convert an ``mfa g2p`` output file into a list of phoneme ids.

    Each non-blank line is expected to be ``word<TAB>PH1 PH2 ...``. Words present
    in the module-level ``default`` dict use that pronunciation instead of the
    file's. Every phoneme is mapped through ``PHONEMES_TO_IDS`` and the id of
    ``"<PAD>"`` is appended once at the end.

    NOTE(review): the id sequence follows the line order of the file, which assumes
    ``mfa g2p`` writes one line per input token in sentence order. If the installed
    MFA version de-duplicates or sorts word types, the phonemes will not follow the
    sentence — confirm.

    Args:
        PHONEMES_TO_IDS: mapping from phoneme symbol to id; must contain ``"<PAD>"``.
        g2p_path: path of the file written by ``mfa g2p``.

    Returns:
        List of phoneme ids terminated by the ``<PAD>`` id.

    Raises:
        ValueError: if a non-blank line has no tab separator.
        KeyError: if a phoneme (or ``"<PAD>"``) is missing from ``PHONEMES_TO_IDS``.
    """
    phonemes_ids = []
    with open(g2p_path, "r") as fin:
        for line_number, line in enumerate(fin, start=1):
            line = line.rstrip()
            if not line:
                # Tolerate blank / trailing empty lines instead of crashing on the split.
                continue
            if "\t" not in line:
                raise ValueError(
                    f"{g2p_path}:{line_number}: expected 'word<TAB>phones', got {line!r}"
                )
            word, word_to_phones = line.split("\t", 1)
            word_to_phones = default.get(word, word_to_phones)
            # split() rather than split(" "): repeated spaces must not yield "" phonemes.
            for phoneme in word_to_phones.split():
                if phoneme not in PHONEMES_TO_IDS:
                    raise KeyError(
                        f"{g2p_path}:{line_number}: phoneme {phoneme!r} of word "
                        f"{word!r} is not in the phoneme vocabulary"
                    )
                phonemes_ids.append(PHONEMES_TO_IDS[phoneme])
    phonemes_ids.append(PHONEMES_TO_IDS["<PAD>"])
    return phonemes_ids

In [13]:
# Evaluation sentences. Each one is synthesised for every reference recording and
# saved as "<1-based position in this list>.wav".
texts = [
    "Something must be done for them whenever they leave their current place and settle in a new home.",
    "He then really thought himself equal to it.",
    "I had felt that the best of life was over for me!",
    "I hope a hundred pounds a year would make them all comfortable.",
    "Perhaps it would have been as well if he had left it wholly to myself.",
    "A pleasant renewing of old acquaintances, that was all I had thought it, not foreseeing that I was shortly to plunge into this whole situation."
]

In [14]:
# Phoneme symbol -> id vocabulary stored next to the feature model.
phonemes_to_ids = json.loads((checkpoint_path / "feature" / "phonemes.json").read_text())

# G2P one sentence at a time: text_to_file makes mfa write G2P_OUTPUT_PATH and
# parse_g2p reads that same file back, so the two calls must stay paired.
phonemes_list = []
for sentence in texts:
    text_to_file(sentence)
    phonemes_list.append(parse_g2p(phonemes_to_ids))

[32mINFO[0m - Generating transcriptions for the 18 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 8 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 12 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 12 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 15 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 25 word types found in the corpus...
Generating pronunciations...


In [15]:
# Load the pickled feature-model object itself (not a state_dict — it is used
# directly via .eval() / .inference below), mapped onto `device`.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which will
# likely reject a full-model pickle; weights_only=False may be needed there — confirm.
feature_model = torch.load(checkpoint_path / "feature" / "feature_model.pth", map_location=device)

In [16]:
# Put the feature model into evaluation mode before inference (presumably an
# nn.Module, whose eval() returns the module itself).
feature_model = feature_model.eval()

In [17]:
def get_tacotron_batch(
    phonemes_ids, reference, speaker_id, device, mels_mean, mels_std
):
    """Assemble the single-utterance input tuple for ``feature_model.inference``.

    Args:
        phonemes_ids: sequence of integer phoneme ids for one utterance.
        reference: 3-D reference mel tensor; it is normalised as
            ``(reference - mels_mean) / mels_std`` and then ``permute(0, 2, 1)``.
        speaker_id: integer speaker id.
        device: target device for the phoneme ids, speaker ids and reference.
        mels_mean, mels_std: normalisation statistics broadcast against
            ``reference`` (presumably shaped to match its mel axis — confirm).

    Returns:
        ``(phonemes, lengths, speakers, reference)``: a ``(1, len(phonemes_ids))``
        int64 tensor, a 1-element int64 length tensor, a 1-element int64 speaker
        tensor and the normalised, permuted reference.
    """
    # Unlike the other outputs, lengths are NOT moved to `device`.
    # NOTE(review): presumably intentional (e.g. pack_padded_sequence wants CPU lengths) — confirm.
    lengths = torch.LongTensor([len(phonemes_ids)])
    normalized_reference = ((reference - mels_mean) / mels_std).permute(0, 2, 1)
    phonemes = torch.LongTensor(phonemes_ids)[None, :]
    speakers = torch.LongTensor([speaker_id])
    return (
        phonemes.to(device),
        lengths,
        speakers.to(device),
        normalized_reference.to(device),
    )

In [18]:
# Root of the reference mels: each "<speaker>/.../*.pkl" below it is one reference.
reference_pathes = Path("references")

In [19]:
# Output root: wavs are written to
# <generated_path>/<generator checkpoint stem>/<speaker>/<reference stem>/<n>.wav
generated_path = Path(f"generated_hifi/{config.checkpoint_name}")

In [20]:
# Speaker name -> integer speaker id stored with the feature model; keyed by the
# speaker directory names found under the references root.
speaker_to_id = json.loads((checkpoint_path / "feature" / "speakers.json").read_text())

In [21]:
# Mel normalisation statistics: the reference is normalised with them in
# get_tacotron_batch and the predicted mels are de-normalised with them before vocoding.
# Load explicitly onto CPU, like every other torch.load here gets a map_location:
# get_tacotron_batch normalises before any .to(device), and the synthesis loop moves
# the stats to `device` itself. Without map_location a CUDA-saved tensor fails to
# load on a CPU-only machine (or lands on GPU and mismatches a CPU reference).
mels_mean = torch.load(checkpoint_path / "feature" / "mels_mean.pth", map_location="cpu").float()
mels_std = torch.load(checkpoint_path / "feature" / "mels_std.pth", map_location="cpu").float()

In [23]:
# int16 PCM scale for the vocoder output (presumably float audio in [-1, 1]).
MAX_WAV_VALUE = 32768.0
# NOTE(review): hard-coded output sample rate; presumably must equal the rate the
# HiFi-GAN checkpoints were trained at — consider reading it from `config`.
SAMPLE_RATE = 22050

# Stage 1: run the feature model once per (reference, text) pair. The mels do not
# depend on the vocoder checkpoint, so they are computed once and reused below.
# Sorting gives a reproducible processing order.
synthesized_mels = []  # (speaker, reference stem, text index, de-normalised mel batch)
for reference in tqdm(sorted(reference_pathes.rglob("*.pkl")), desc="feature model"):
    speaker = reference.parent.name  # the speaker is the reference's parent directory name
    speaker_id = speaker_to_id[speaker]
    # Load onto CPU so normalisation against the mel statistics never mixes devices.
    ref_mel = torch.load(reference, map_location="cpu")
    for i, phonemes in enumerate(phonemes_list):
        batch = get_tacotron_batch(phonemes, ref_mel, speaker_id, device, mels_mean, mels_std)
        with torch.no_grad():
            mels = feature_model.inference(batch)
            mels = mels.permute(0, 2, 1).squeeze(0)
            # Invert the (x - mean) / std normalisation used for the model's inputs.
            mels = mels * mels_std.to(device) + mels_mean.to(device)
        synthesized_mels.append((speaker, reference.stem, i, mels.unsqueeze(0)))

# Stage 2: vocode every mel with each HiFi-GAN checkpoint. Each checkpoint is read
# from disk and built exactly once, instead of once per (reference, text) pair.
for generator_path in tqdm(generators, desc="vocoders"):
    state_dict = torch.load(generator_path, map_location=device)
    generator = Generator(config=config.train_hifi.model_param, num_mels=config.n_mels).to(device)
    generator.load_state_dict(state_dict["generator"])
    generator.remove_weight_norm()
    generator.eval()
    for speaker, reference_stem, i, x in synthesized_mels:
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            y_g_hat = generator(x)
        # Clamp before the int16 cast: a sample at +1.0 scales to 32768, which does
        # not fit in int16 and would wrap around to -32768 (an audible click).
        audio = (y_g_hat.squeeze() * MAX_WAV_VALUE).clamp(-32768, 32767)
        audio = audio.type(torch.int16).cpu().numpy()
        save_path = generated_path / generator_path.stem / speaker / reference_stem
        save_path.mkdir(exist_ok=True, parents=True)
        wav_write(save_path / f"{i + 1}.wav", SAMPLE_RATE, audio)

  0%|          | 0/45 [00:00<?, ?it/s]

KeyboardInterrupt: 