# convert autovc output mel spec to wav

In [11]:
import os
import pickle
import torch
import numpy as np
from math import ceil
from model_vc import Generator
from model_vae import VAE
from tqdm import tqdm
from hparams import DATA_PATH
import random

EVAL_SCHEME_LIST = ['TRAINxTRAIN', 'TRAINxTEST', 'TESTxTEST', 'TESTxTRAIN', 'TRAINxVOX', 'TESTxVOX', 'TRAINxPSEUDO', 'TESTxPSEUDO']
EVAL_SCHEME = EVAL_SCHEME_LIST[-1]
EVAL_MODEL_LIST = ['default', 'enhanced_content', 'enhanced_content_emb', 'enhanced_content_emb_l1_vae', 'enhanced_content_l1_vae', 'test']
EVAL_MODEL = EVAL_MODEL_LIST[-3]

# # default model
if EVAL_MODEL == 'default':
    model_path = "/home/yrb/code/ID-DEID/data/model/freq32/autovc_epoch_95_loss_0.0026.ckpt"
    output_path = f'/home/yrb/data/ID-DEID_data/autovc_poster_infer/{EVAL_SCHEME}/full/default/'

# content_loss
if EVAL_MODEL == 'enhanced_content':
    model_path = "/home/yrb/code/ID-DEID/data/model/enhanced_freq32/autovc_epoch_98_loss_0.0350.ckpt"
    output_path = f'/home/yrb/data/ID-DEID_data/autovc_poster_infer/{EVAL_SCHEME}/full/enhanced_content/'

# content+emb_loss
if EVAL_MODEL == 'enhanced_content_emb':
    model_path = "/home/yrb/code/ID-DEID/data/model/enhanced_freq32+emb_loss/autovc_epoch_97_loss_0.0058.ckpt"
    output_path = f'/home/yrb/data/ID-DEID_data/autovc_poster_infer/{EVAL_SCHEME}/full/enhanced_content_emb/'

# # content_loss + mse_vae
# if EVAL_MODEL == 'enhanced_content_mse_vae':
#     model_path = "/home/yrb/code/ID-DEID/data/model/enhanced_freq32/autovc_epoch_98_loss_0.0350.ckpt"
#     output_path = f'/home/yrb/data/ID-DEID_data/autovc_poster_infer/{EVAL_SCHEME}/full/enhanced_content_mse_vae/'
#     vae_path = "/home/yrb/code/ID-DEID/data/vae_model/vae_wo_cos_on_voxceleb.ckpt"

# content_loss + l1_vae
if EVAL_MODEL == 'enhanced_content_l1_vae':
    model_path = "/home/yrb/code/ID-DEID/data/model/enhanced_freq32/autovc_epoch_98_loss_0.0350.ckpt"
    output_path = f'/home/yrb/data/ID-DEID_data/autovc_poster_infer/{EVAL_SCHEME}/full/enhanced_content_l1_vae/'
    vae_path = "/home/yrb/code/ID-DEID/data/vae_model/l1_vae_on_voxceleb_finetune_on_wsj.ckpt"

# content+emb_loss + l1_vae
if EVAL_MODEL == 'enhanced_content_emb_l1_vae':
    model_path = "/home/yrb/code/ID-DEID/data/model/enhanced_freq32+emb_loss/autovc_epoch_97_loss_0.0058.ckpt"
    output_path = f'/home/yrb/data/ID-DEID_data/autovc_poster_infer/{EVAL_SCHEME}/full/enhanced_content_emb_l1_vae/'
    vae_path = "/home/yrb/code/ID-DEID/data/vae_model/l1_vae_on_voxceleb_finetune_on_wsj.ckpt"

# content_loss
if EVAL_MODEL == 'test':
    model_path = "/home/yrb/code/ID-DEID/data/model/autovc_best.ckpt"
    output_path = f'/home/yrb/data/ID-DEID_data/autovc_poster_infer/{EVAL_SCHEME}/full/test/'

torch.set_num_threads(4)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

G = Generator(32,256,512,32).eval().to(device)
g_checkpoint = torch.load(model_path, map_location=device)
G.load_state_dict(g_checkpoint['model'])
print('epoch', int((g_checkpoint['iter']+1)/5000))

if 'PSEUDO' in EVAL_SCHEME:
    Vae = VAE(256, 384, 64, device).to(device)
    Vae_checkpoint = torch.load(vae_path, map_location=device)
    Vae.load_state_dict(Vae_checkpoint['model'])

metadata_wsj = pickle.load(open(DATA_PATH+'wsj_spmel/train.pkl', "rb"))
metadata_vctk = pickle.load(open(DATA_PATH+'vctk_spmel/train.pkl', "rb"))
metadata_vox = pickle.load(open(DATA_PATH+'voxceleb_spmel/train.pkl', "rb"))

# only run portion of meta
metadata = []
src_utter_id_list = []
src_list = []
trg_list = []
if EVAL_SCHEME == 'TRAINxTRAIN':
    for sbmt_i in range(100):
        metadata.append(metadata_wsj[sbmt_i])
        src_list.append(metadata_wsj[sbmt_i][0])
        src_utter_id_list.append(random.choice([i for i in range(3, len(metadata_wsj[sbmt_i]))]))
        
        metadata.append(metadata_wsj[sbmt_i+100])
        trg_list.append(metadata_wsj[sbmt_i+100][0])
        src_utter_id_list.append(None) # trgs don't need utters
elif EVAL_SCHEME == 'TRAINxTEST':
    for sbmt_i in range(100):
        metadata.append(metadata_wsj[sbmt_i])
        src_list.append(metadata_wsj[sbmt_i][0])
        src_utter_id_list.append(random.choice([i for i in range(3, len(metadata_wsj[sbmt_i]))]))
        
        metadata.append(metadata_wsj[sbmt_i+300])
        trg_list.append(metadata_wsj[sbmt_i+300][0])
        src_utter_id_list.append(None)
elif EVAL_SCHEME == 'TESTxTEST':
    for sbmt_i in range(100):
        metadata.append(metadata_wsj[sbmt_i+300])
        src_list.append(metadata_wsj[sbmt_i+300][0])
        src_utter_id_list.append(random.choice([i for i in range(3, len(metadata_wsj[sbmt_i+300]))]))
        
        metadata.append(metadata_wsj[sbmt_i+301])
        trg_list.append(metadata_wsj[sbmt_i+301][0])
        src_utter_id_list.append(None) # trgs don't need utters
elif EVAL_SCHEME == 'TESTxTRAIN':
    for sbmt_i in range(100):
        metadata.append(metadata_wsj[sbmt_i+300])
        src_list.append(metadata_wsj[sbmt_i+300][0])
        src_utter_id_list.append(random.choice([i for i in range(3, len(metadata_wsj[sbmt_i+300]))]))

        metadata.append(metadata_wsj[sbmt_i])
        trg_list.append(metadata_wsj[sbmt_i][0])
        src_utter_id_list.append(None)
elif EVAL_SCHEME == 'TRAINxVOX':
    for sbmt_i in range(100):
        metadata.append(metadata_wsj[sbmt_i])
        src_list.append(metadata_wsj[sbmt_i][0])
        src_utter_id_list.append(random.choice([i for i in range(3, len(metadata_wsj[sbmt_i]))]))
        
        metadata.append(metadata_vox[sbmt_i])
        trg_list.append(metadata_vox[sbmt_i][0])
        src_utter_id_list.append(None) # trgs don't need utters
elif EVAL_SCHEME == 'TRAINxPSEUDO':
    noise = torch.randn(100, 64).to(device)
    generated_emb = Vae.Decoder(noise)
    generated_emb = generated_emb.cpu().detach().numpy()
    for sbmt_i in range(100):
        metadata.append(metadata_wsj[sbmt_i])
        src_list.append(metadata_wsj[sbmt_i][0])
        src_utter_id_list.append(random.choice([i for i in range(3, len(metadata_wsj[sbmt_i]))]))
        
        fake_meta = [f'pseudo_{sbmt_i}', generated_emb[sbmt_i], None]
        metadata.append(fake_meta)
        trg_list.append(fake_meta[0])
        src_utter_id_list.append(None) # trgs don't need utters
elif EVAL_SCHEME == 'TESTxPSEUDO':
    noise = torch.randn(100, 64).to(device)
    generated_emb = Vae.Decoder(noise)
    generated_emb = generated_emb.cpu().detach().numpy()
    for sbmt_i in range(100):
        metadata.append(metadata_wsj[sbmt_i+300])
        src_list.append(metadata_wsj[sbmt_i+300][0])
        src_utter_id_list.append(random.choice([i for i in range(3, len(metadata_wsj[sbmt_i+300]))]))
        
        fake_meta = [f'pseudo_{sbmt_i}', generated_emb[sbmt_i], None]
        metadata.append(fake_meta)
        trg_list.append(fake_meta[0])
        src_utter_id_list.append(None) # trgs don't need utters
else:
    raise NotImplementedError

epoch 97


In [12]:
def pad_seq(x, base=32):
    len_out = int(base * ceil(float(x.shape[0])/base))
    len_pad = len_out - x.shape[0]
    assert len_pad >= 0
    return np.pad(x, ((0,len_pad),(0,0)), 'constant'), len_pad

def determine_dataset(spkrid):
    if spkrid == 's5' or (len(spkrid) == 4 and spkrid.startswith('p')):
        return 'vctk'
    elif len(spkrid) == 3:
        return 'wsj'
    elif len(spkrid) == 7 and spkrid.startswith('id'):
        return 'vox'
    else:
        raise NotImplementedError(f'invalid spkrid: {spkrid}')

def get_wavs_and_spmel_dir(spkrid):
    dataset_name = determine_dataset(spkrid)
    if dataset_name == 'wsj':
        return os.path.join(DATA_PATH, 'wsj_wavs'), os.path.join(DATA_PATH, 'wsj_spmel')
    elif dataset_name == 'vctk':
        return os.path.join(DATA_PATH, 'vctk_wavs'), os.path.join(DATA_PATH, 'vctk_spmel')
    elif dataset_name == 'vox':
        return os.path.join(DATA_PATH, 'voxceleb_wavs'), os.path.join(DATA_PATH, 'voxceleb_spmel')
    else:
        raise NotImplementedError(f'invalid dataset_name: {dataset_name}')

src_utter_id_list_path = os.path.join(output_path, '../utter_id_list.pkl')
os.makedirs(os.path.dirname(src_utter_id_list_path), exist_ok=True)
if os.path.exists(src_utter_id_list_path):
    with open(src_utter_id_list_path, 'rb') as ff:
        src_utter_id_list = pickle.load(ff)
else:
    with open(src_utter_id_list_path, 'wb') as ff:
        pickle.dump(src_utter_id_list, ff)

spect_vc = []
cropping = True

for sbmt_i_idx, sbmt_i in enumerate(tqdm(metadata)):
    if sbmt_i[0] not in src_list: continue
    if src_utter_id_list[sbmt_i_idx] == None: continue

    # get original utterance name, wavs dir and spmel dir
    wavs_dir, spmel_dir = get_wavs_and_spmel_dir(sbmt_i[0])
    x_org = sbmt_i[src_utter_id_list[sbmt_i_idx]]

    # copy original wav
    src_wav_path = os.path.join(wavs_dir, x_org.replace('.npy', '.wav'))
    cp_wav_path = os.path.join(output_path, 'src')
    os.makedirs(cp_wav_path, exist_ok=True)
    os.system(f"cp {src_wav_path} {cp_wav_path}")

    # get original utterance spmel and embedding
    x_org = np.load(os.path.join(spmel_dir, x_org))
    x_org, len_pad = pad_seq(x_org, base=128)
    start = 0 if cropping else 0
    uttr_org = torch.from_numpy(x_org[np.newaxis, start:, :]).to(device)
    uttr_org = uttr_org.reshape(-1, 128, 80) if cropping else uttr_org
    emb_org = torch.from_numpy(sbmt_i[1][np.newaxis, :]).repeat(uttr_org.shape[0],1).to(device)
    
    for sbmt_j_idx, sbmt_j in enumerate(metadata):
        if sbmt_j[0] not in trg_list: continue
        if src_list.index(sbmt_i[0]) != trg_list.index(sbmt_j[0]): continue

        # copy target wav
        if 'pseudo' not in sbmt_j[0]:
            wavs_dir, spmel_dir = get_wavs_and_spmel_dir(sbmt_j[0])
            trg_utter = sbmt_j[2]
            trg_wav_path = os.path.join(wavs_dir, trg_utter.replace('.npy', '.wav'))
            cp_wav_path = os.path.join(output_path, 'trg')
            os.makedirs(cp_wav_path, exist_ok=True)
            new_trg_wav_name = trg_utter.replace('.npy', '.wav').replace('/', '_')
            os.system(f"cp {trg_wav_path} {cp_wav_path}/{new_trg_wav_name}")
            
        emb_trg = torch.from_numpy(sbmt_j[1][np.newaxis, :]).repeat(uttr_org.shape[0],1).to(device)
        
        with torch.no_grad():
            _, x_identic_psnt, _ = G(uttr_org, emb_org, emb_trg) # TODO: emb_trg
        x_identic_psnt = x_identic_psnt.reshape(1, 1, -1, 80)
            
        if len_pad == 0:
            uttr_trg = x_identic_psnt[0, 0, :, :].cpu().numpy()
        else:
            uttr_trg = x_identic_psnt[0, 0, :-len_pad, :].cpu().numpy()
        
        spect_vc.append( ('{}x{}'.format(sbmt_i[0], sbmt_j[0]), uttr_trg) )
    with open(os.path.join(output_path, 'results.pkl'), 'wb') as handle:
        pickle.dump(spect_vc, handle) 
    

100%|██████████| 200/200 [00:09<00:00, 20.50it/s]


# Hifi-GAN infer

In [13]:
import os
from hparams import DATA_PATH
command = "python /home/yx/hifi-gan/inference_e2e.py "
# Add pickle path here
command += "--input_pkl_dir "
command += os.path.join(output_path, 'results.pkl') + " "
# Add output path here
command += "--output_wavs_dir "
command += output_path
print(command)
os.system(command)

python /home/yx/hifi-gan/inference_e2e.py --input_pkl_dir /home/yrb/data/ID-DEID_data/autovc_poster_infer/TESTxPSEUDO/full/enhanced_content_emb_l1_vae/results.pkl --output_wavs_dir /home/yrb/data/ID-DEID_data/autovc_poster_infer/TESTxPSEUDO/full/enhanced_content_emb_l1_vae/
Initializing Inference Process..
Loading '/home/yx/hifi-gan/pretrained/VCTK_V1/g_03280000'
Complete.
Removing weight norm...
/home/yrb/data/ID-DEID_data/autovc_poster_infer/TESTxPSEUDO/full/enhanced_content_emb_l1_vae/49gxpseudo_0.wav
/home/yrb/data/ID-DEID_data/autovc_poster_infer/TESTxPSEUDO/full/enhanced_content_emb_l1_vae/49hxpseudo_1.wav
/home/yrb/data/ID-DEID_data/autovc_poster_infer/TESTxPSEUDO/full/enhanced_content_emb_l1_vae/49ixpseudo_2.wav
/home/yrb/data/ID-DEID_data/autovc_poster_infer/TESTxPSEUDO/full/enhanced_content_emb_l1_vae/49jxpseudo_3.wav
/home/yrb/data/ID-DEID_data/autovc_poster_infer/TESTxPSEUDO/full/enhanced_content_emb_l1_vae/49kxpseudo_4.wav
/home/yrb/data/ID-DEID_data/autovc_poster_infer/TE

0