In [1]:
import os
import torch
import pickle
import numpy as np
import soundfile

from hparams import hparams
from utils import pad_seq_to_2, quantize_f0_numpy
from model import Generator_3 as Generator
from synthesis import build_model, wavegen

# Pick the compute device once; the synthesis model, Generator and all input
# tensors below are placed on it.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# -------------------------
# Demo metadata (loaded once, reused for every Generator checkpoint)
# -------------------------
# NOTE: pickle.load (and torch.load) can execute arbitrary code while
# unpickling -- only point these at trusted local files.
metadata_path = 'assets/demo_m2f.pkl'
with open(metadata_path, "rb") as f:
    metadata = pickle.load(f)

# -------------------------
# Waveform synthesis model (loaded once, reused for every Generator checkpoint)
# -------------------------
synth_model = build_model()
synth_model = synth_model.to(device)
synth_ckpt = torch.load(
    "assets/checkpoint_step001000000_ema.pth",
    map_location=torch.device(device),
)
synth_model.load_state_dict(synth_ckpt["state_dict"])

# -------------------------
# Generator checkpoints to run, given as file stems: each name is resolved to
# assets/<name>.ckpt, e.g. "800000-G-B10" -> assets/800000-G-B10.ckpt
# -------------------------
generator_ckpts = [
    "774000-G",  # edit/extend this list as needed
]

# -------------------------
# Main processing
# -------------------------
# Number of frames fed to the Generator: longer utterances are cropped and
# shorter ones zero-padded to this length.
MAX_LEN = 192


def _prepare_utterance(sbmt, max_len=MAX_LEN):
    """Turn one metadata entry into batch-of-one Generator inputs on ``device``.

    ``sbmt[1]`` is the speaker embedding (numpy) and ``sbmt[2]`` is the tuple
    ``(spectrogram, f0, length, uid)``.  The spectrogram is cropped to
    ``max_len`` frames and padded with ``pad_seq_to_2``.  The F0 is cropped /
    zero-padded to exactly ``max_len`` rows; a 1-D F0 contour is quantized
    with ``quantize_f0_numpy``, a 2-D F0 is used as-is (presumably already
    one-hot -- TODO confirm).

    Returns:
        ``(emb, uttr_pad, f0_onehot, length, uid)``.  ``emb`` always has a
        leading batch axis; ``length`` is the (possibly cropped) frame count
        used later to trim Generator outputs.
    """
    emb = torch.from_numpy(sbmt[1]).to(device)
    if emb.dim() == 1:
        emb = emb.unsqueeze(0)

    x, f0, length, uid = sbmt[2]
    if x.shape[0] > max_len:
        x = x[:max_len, :]
        length = max_len
    # Slicing the first axis crops both 1-D contours and 2-D (frames, bins) F0.
    f0 = f0[:max_len]

    uttr_pad, _ = pad_seq_to_2(x[np.newaxis, :, :], max_len)
    uttr_pad = torch.from_numpy(uttr_pad).to(device)

    # Pad by the F0's own length (not the metadata length) so it always ends
    # up with exactly max_len rows, matching the padded spectrogram it is
    # concatenated with.
    pad_rows = max_len - f0.shape[0]
    if f0.ndim == 1:
        f0_pad = np.pad(f0, (0, pad_rows), 'constant', constant_values=(0, 0))
        f0_quantized = quantize_f0_numpy(f0_pad)[0]
    else:
        f0_quantized = np.pad(f0, ((0, pad_rows), (0, 0)),
                              'constant', constant_values=(0, 0))
    f0_onehot = torch.from_numpy(f0_quantized[np.newaxis, :, :]).to(device)
    return emb, uttr_pad, f0_onehot, length, uid


# The source (metadata[0]) and target (metadata[1]) utterances do not depend
# on the Generator checkpoint, so prepare them once rather than per checkpoint.
sbmt_i = metadata[0]
sbmt_j = metadata[1]
emb_org, uttr_org_pad, f0_org_onehot, len_org, uid_org = _prepare_utterance(sbmt_i)
emb_trg, uttr_trg_pad, f0_trg_onehot, len_trg, uid_trg = _prepare_utterance(sbmt_j)

uttr_f0_org = torch.cat((uttr_org_pad, f0_org_onehot), dim=-1)
# No F0 converter is used: the "target F0" input is built directly from the
# target utterance's own spectrogram + F0.
# NOTE(review): because this also carries the target *spectrogram*, every
# condition containing 'F' takes the Generator's whole first argument from the
# target utterance, not just its pitch -- confirm this is intended.
uttr_f0_trg = torch.cat((uttr_trg_pad, f0_trg_onehot), dim=-1)

# Each letter in a condition swaps one Generator input from source to target:
#   'F' -> first argument (spectrogram + F0),
#   'R' -> second argument (spectrogram only),
#   'U' -> speaker embedding.
conditions = ['R', 'F', 'U', 'FU', 'RF', 'RU', 'RFU']
generator_inputs = {
    cond: (
        uttr_f0_trg if 'F' in cond else uttr_f0_org,
        uttr_trg_pad if 'R' in cond else uttr_org_pad,
        emb_trg if 'U' in cond else emb_org,
    )
    for cond in conditions
}

for ckpt_name in generator_ckpts:
    print(f"Processing Generator checkpoint: {ckpt_name}")

    # Create a dedicated results directory for this checkpoint
    results_dir = os.path.join("results", ckpt_name)
    os.makedirs(results_dir, exist_ok=True)

    # Load the Generator and its checkpoint. strict=False tolerates key
    # mismatches, so report them instead of silently running with weights
    # that were never loaded.
    G = Generator(hparams).eval().to(device)
    g_ckpt_path = os.path.join("assets", f"{ckpt_name}.ckpt")
    g_checkpoint = torch.load(g_ckpt_path, map_location=lambda storage, loc: storage)
    load_result = G.load_state_dict(g_checkpoint['model'], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"  Warning: checkpoint {ckpt_name} does not match the model: "
              f"missing keys={load_result.missing_keys}, "
              f"unexpected keys={load_result.unexpected_keys}")

    # Run the Generator under each condition.
    spect_vc = []
    with torch.no_grad():
        for condition in conditions:
            x_identic_val, var, mu = G(*generator_inputs[condition])
            # Trim to the target length when 'R' swapped in the target second
            # argument, otherwise to the source length (presumably the second
            # argument determines output timing -- TODO confirm).
            out_len = len_trg if 'R' in condition else len_org
            uttr_trg_out = x_identic_val[0, :out_len, :].cpu().numpy()

            spect_name = f"{sbmt_i[0]}_{sbmt_j[0]}_{uid_org}_{condition}"
            spect_vc.append((spect_name, uttr_trg_out))

    print("Generated spectrograms for conditions:", [name for name, _ in spect_vc])

    # Spectrogram-to-waveform conversion; one 16 kHz WAV per condition.
    for name, spect_data in spect_vc:
        print(f"Generating waveform for {name}")
        waveform = wavegen(synth_model, c=spect_data)
        out_path = os.path.join(results_dir, f"{name}.wav")
        soundfile.write(out_path, waveform, samplerate=16000)

    print(f"Results saved in directory: {results_dir}\n")



Processing Generator checkpoint: 774000-G
Generated spectrograms for conditions: ['p558_p547_p558_001.npy_R', 'p558_p547_p558_001.npy_F', 'p558_p547_p558_001.npy_U', 'p558_p547_p558_001.npy_FU', 'p558_p547_p558_001.npy_RF', 'p558_p547_p558_001.npy_RU', 'p558_p547_p558_001.npy_RFU']
Generating waveform for p558_p547_p558_001.npy_R


100%|█████████████████████████████████████████████████████████████████████████████████████████| 30976/30976 [02:56<00:00, 175.46it/s]


Generating waveform for p558_p547_p558_001.npy_F


100%|█████████████████████████████████████████████████████████████████████████████████████████| 34560/34560 [03:17<00:00, 175.38it/s]


Generating waveform for p558_p547_p558_001.npy_U


100%|█████████████████████████████████████████████████████████████████████████████████████████| 34560/34560 [03:16<00:00, 175.97it/s]


Generating waveform for p558_p547_p558_001.npy_FU


100%|█████████████████████████████████████████████████████████████████████████████████████████| 34560/34560 [03:18<00:00, 173.71it/s]


Generating waveform for p558_p547_p558_001.npy_RF


100%|█████████████████████████████████████████████████████████████████████████████████████████| 30976/30976 [02:56<00:00, 175.85it/s]


Generating waveform for p558_p547_p558_001.npy_RU


100%|█████████████████████████████████████████████████████████████████████████████████████████| 30976/30976 [02:56<00:00, 175.80it/s]


Generating waveform for p558_p547_p558_001.npy_RFU


100%|█████████████████████████████████████████████████████████████████████████████████████████| 30976/30976 [02:57<00:00, 174.51it/s]

Results saved in directory: results/774000-G




