In [1]:
import json
import torch
from pathlib import Path

from scipy.io.wavfile import write as wav_write
from tqdm.notebook import tqdm

from src.models.hifi_gan.models import Generator, load_model as load_hifi
from src.train_config import TrainParams, load_config

In [2]:
config = load_config("configs/esd_tune.yml")

In [3]:
device = "cuda:0"  # config.device

In [4]:
checkpoint_path = Path(f"checkpoints/{config.checkpoint_name}")

In [5]:
generators = [file for file in (checkpoint_path / "hifi").rglob("*.*") if file.name.startswith("g_")]

In [6]:
reference_pathes = Path("references/")

In [7]:
generated_path = Path(f"generated_refs/{config.checkpoint_name}")

In [8]:
with open(checkpoint_path / "feature"/ "speakers.json") as f:
    speaker_to_id = json.load(f)

In [9]:
mels_mean = torch.load(checkpoint_path / "feature" / "mels_mean.pth", map_location=device).float()
mels_std = torch.load(checkpoint_path / "feature" / "mels_std.pth", map_location=device).float()

In [10]:
for reference in tqdm(list(reference_pathes.rglob("*.pkl"))):
    speaker = reference.parent.name
    emotion = reference.stem
    speaker_id = speaker_to_id[speaker]
    ref_mel = torch.load(reference, map_location=device)
    ref_mel = (ref_mel - mels_mean) / mels_std
    ref_mel = ref_mel.unsqueeze(0)
    ref_mel = ref_mel.permute(0, 2, 1).to(device)
    with torch.no_grad():
        mels = ref_mel
        mels = mels.permute(0, 2, 1).squeeze(0)
        mels = mels * mels_std.to(device) + mels_mean.to(device)
        x = mels.unsqueeze(0)
        for generator_path in generators:
            state_dict = torch.load(generator_path, map_location="cpu")
            state_dict["generator"] = {k: v.to(device) for k, v in state_dict["generator"].items()}
            generator = Generator(config=config.train_hifi.model_param, num_mels=config.n_mels).to(device)
            generator.load_state_dict(state_dict["generator"])
            generator.remove_weight_norm()
            generator.eval()
            y_g_hat = generator(x)
            audio = y_g_hat.squeeze()
            audio = audio * 32768
            audio = audio.type(torch.int16).detach().cpu().numpy()
            save_path = generated_path / generator_path.stem / speaker / reference.stem
            save_path.mkdir(exist_ok=True, parents=True)
            wav_write(save_path / f"{emotion}.wav", 22050, audio)
            torch.cuda.empty_cache()


  0%|          | 0/45 [00:00<?, ?it/s]