In [103]:
import json
import torch
import re
import mlflow
from pathlib import Path
from src.models.hifi_gan.models import Generator, load_model as load_hifi
from src.train_config import TrainParams, load_config
from src.preprocessing.text.cleaners import english_cleaners
import subprocess
from scipy.io.wavfile import write as wav_write

In [37]:
config = load_config("configs/esd_vctk.yml")

In [38]:
device=torch.device("cuda:1")

In [61]:
G2P_MODEL_PATH = "models/g2p/english_g2p.zip"
G2P_OUTPUT_PATH = "predictions/to_g2p.txt"

In [39]:
vocoder = load_hifi(
    model_path=config.pretrained_hifi,
    hifi_config=config.train_hifi.model_param,
    num_mels=config.n_mels,
    device=device,
)

In [102]:
vocoder.eval()

Generator(
  (conv_pre): Conv1d(80, 512, kernel_size=(7,), stride=(1,), padding=(3,))
  (ups): ModuleList(
    (0): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
    (1): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(4,))
    (2): ConvTranspose1d(128, 64, kernel_size=(4,), stride=(2,), padding=(1,))
    (3): ConvTranspose1d(64, 32, kernel_size=(4,), stride=(2,), padding=(1,))
  )
  (resblocks): ModuleList(
    (0): ResBlock1(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        (2): Conv1d(256, 256, kernel_size=(3,), stride

In [68]:
def text_to_file(user_query: str) -> None:
    text_path = Path("tmp.txt")
    with open(text_path, "w") as fout:
        normalized_content = english_cleaners(user_query)
        normalized_content = " ".join(re.findall("[a-zA-Z]+", normalized_content))
        fout.write(normalized_content)
    subprocess.call(
        ["mfa", "g2p", G2P_MODEL_PATH, text_path.absolute(), G2P_OUTPUT_PATH]
    )
    text_path.unlink()

In [65]:
def parse_g2p(PHONEMES_TO_IDS, g2p_path: str = G2P_OUTPUT_PATH):
    with open(g2p_path, "r") as fin:
        phonemes_ids = []
        phonemes = []
        for line in fin:
            _, word_to_phones = line.rstrip().split("\t", 1)
            phonemes.extend(word_to_phones.split(" "))
            phonemes_ids.extend(
                [PHONEMES_TO_IDS[ph] for ph in word_to_phones.split(" ")]
            )
        phonemes_ids.append(PHONEMES_TO_IDS["<PAD>"])
    return phonemes_ids

In [47]:
checkpoints_path = Path("checkpoints/")

In [10]:
path = Path("mlruns")

In [40]:
texts = [
    "I can't believe he did it!",
    "He has abandoned all the traditions here.",
    "I mean, can you imagine anything more inappropriate?",
    "That's what the crowd never expected, all of us were astonished to see it.",
    "Now he wonders why people think he is a bit odd.",
    "That is insane!"
]

In [127]:
exp_to_phonemes = {}
for i in [3, 7, 15]:
    exp_id = mlflow.get_experiment_by_name(f"esd_vctk_{i}").experiment_id
    with open(f"checkpoints/esd_vctk_{i}/feature/phonemes.json") as f:
        phonemes_to_ids = json.load(f)
    phonemes_list = []
    for t in texts:
        text_to_file(t)
        phoneme_ids = parse_g2p(phonemes_to_ids)
        phonemes_list.append(phoneme_ids)
    exp_to_phonemes[exp_id] = phonemes_list

[32mINFO[0m - Generating transcriptions for the 7 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 7 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 8 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 15 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 11 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 3 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 7 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 7 word types found in the corpus...
Generating pronunciations...
[32mINFO[0m - Generating transcriptions for the 8 word types found in the co

In [87]:
exp_to_speakers = {}
for i in [3, 7, 15]:
    exp_id = mlflow.get_experiment_by_name(f"esd_vctk_{i}").experiment_id
    with open(f"checkpoints/esd_vctk_{i}/feature/speakers.json") as f:
        speaker_to_ids = json.load(f)
    exp_to_speakers[exp_id] = {k: v for k, v in speaker_to_ids.items() if not k.startswith("p")}

In [97]:
exp_id_to_name = {}
for i in [3, 7, 15]:
    exp_id = mlflow.get_experiment_by_name(f"esd_vctk_{i}").experiment_id
    exp_id_to_name[exp_id] = f"esd_vctk_{i}"

In [98]:
exp_id_to_name

{'2': 'esd_vctk_3', '1': 'esd_vctk_7', '3': 'esd_vctk_15'}

In [134]:
!ls -lr mlruns/1/00d98c4d3b1448dd8270eee06cf509e1/artifacts/checkpoints/esd_vctk_7/feature/feature_model.pth/data/

/bin/bash: /home/uadmin/miniconda3/envs/emotts/lib/libtinfo.so.6: no version information available (required by /bin/bash)
total 304628
-rw-rw-r-- 1 uadmin uadmin        28 Jan 30 11:20 pickle_module_info.txt
-rw-rw-r-- 1 uadmin uadmin 311934937 Jan 30 11:20 model.pth


In [76]:
p = list(path.rglob("model.pth"))

In [84]:
exp_index = 7
run_index = 6

In [141]:
def get_tacotron_batch(
    phonemes_ids, reference, speaker_id, device, mels_mean, mels_std
):
    text_lengths_tensor = torch.LongTensor([len(phonemes_ids)])
    reference = (reference - mels_mean) / mels_std
    reference = reference.permute(0, 2, 1).to(device)
    phonemes_ids_tensor = torch.LongTensor(phonemes_ids).unsqueeze(0).to(device)
    speaker_ids_tensor = torch.LongTensor([speaker_id]).to(device)
    return phonemes_ids_tensor, text_lengths_tensor, speaker_ids_tensor, reference

In [91]:
reference_pathes = Path("references/")

In [93]:
list((reference_pathes / "0011").rglob("*.pkl"))[0].stem

'happy'

In [110]:
generated_path = Path("generated_custom")

In [142]:
for mp in p:
    model = torch.load(mp)
    model.to(device)
    model.eval()
    exp_id = list(mp.parents)[exp_index].name
    run_id = list(mp.parents)[run_index].name
    speakers = exp_to_speakers[exp_id]
    exp_name = exp_id_to_name[exp_id]
    mels_mean = torch.load(f"checkpoints/{exp_name}/feature/mels_mean.pth").to("cpu").float()
    mels_std = torch.load(f"checkpoints/{exp_name}/feature/mels_std.pth").to("cpu").float()
    for speaker, speaker_id in speakers.items():
        for reference in (reference_pathes / speaker).rglob("*.pkl"):
            emo = reference.stem
            ref_mel = torch.load(reference)
            save_path = generated_path / exp_id / run_id / speaker / emo
            save_path.mkdir(exist_ok=True, parents=True)
            for i, phonemes_ids in enumerate(exp_to_phonemes[exp_id]):
                batch = get_tacotron_batch(phonemes_ids, ref_mel, speaker_id, device, mels_mean, mels_std)
                with torch.no_grad():
                    mels = model.inference(batch)
                    mels = mels.permute(0, 2, 1).squeeze(0)
                    mels = mels * mels_std.to(device) + mels_mean.to(device)
                    x = mels.unsqueeze(0)
                    y_g_hat = vocoder(x)
                    audio = y_g_hat.squeeze()
                    audio = audio * 32768
                    audio = audio.type(torch.int16).detach().cpu().numpy()
                    wav_write(save_path / f"{i + 1}.wav", 22050, audio)

KeyError: '4'