In [1]:
import pickle
import librosa
import random
import soundfile as sf
import os
import torch
import numpy as np
import IPython.display as ipd
from util.convert import get_trans_mel,get_mel,convert,save_wave
from factory.AutoVC import AutoVC
from factory.ASGANVC import ASGANVC
from factory.VQVC import VQVC
from factory.AgainVC import AgainVC

In [2]:
# Torch device string. In the cells below it is passed to `.to(device)` for
# models and speaker embeddings and used as `map_location` when checkpoints
# are loaded.
device = "cuda:0"

In [3]:
# Input folder: train.pkl and the .npy arrays referenced by it are read from here.
ROOT = "train_spmel_vctk80"
# Output folder: every generated .wav file is written somewhere below it.
SAVE_DIR = "generate_wav"

In [4]:
# Indices into `metadata` selecting the speakers used in the experiments.
speakers = [
    20, 59, 54, 45, 11,
    35, 3, 41, 12, 0,
]

# Names of the conversion systems; they double as output sub-folder names.
model = [
    'autovc',
    'asganvc',
    'vqvc',
    'againvc',
]

# Sampling rate (Hz) handed to save_wave for every written file.
sample_rate = 22050

In [5]:
# AgainVC configuration. In the cells visible here only
# build_config['model']['params'] is consumed (unpacked into the AgainVC
# constructor); the optimizer section is carried along unchanged.
build_config = {
    'model_name': 'again',
    'model': {
        'params': {
            'encoder_params': {
                'c_in': 80,
                'c_h': 256,
                'c_out': 4,
                'n_conv_blocks': 6,
                'subsample': [1] * 6,
            },
            'decoder_params': {
                'c_in': 4,
                'c_h': 256,
                'c_out': 80,
                'n_conv_blocks': 6,
                'upsample': [1] * 6,
            },
            'activation_params': {
                'act': 'sigmoid',
                'params': {'alpha': 0.1},
            },
        },
    },
    'optimizer': {
        'params': {
            'lr': 0.0005,
            'betas': [0.9, 0.999],
            'amsgrad': True,
            'weight_decay': 0.0001,
        },
        'grad_norm': 3,
    },
}

In [6]:
# Load the dataset metadata. The context manager closes the file handle as
# soon as unpickling finishes instead of leaving it open until garbage
# collection.
# As used below: metadata[speaker][1] is a NumPy array (fed to
# torch.from_numpy) and entries from index 3 upward are path strings relative
# to ROOT.
# NOTE: unpickling can execute arbitrary code; only load a train.pkl that
# comes from a trusted source.
with open(f'{ROOT}/train.pkl', "rb") as metadata_file:
    metadata = pickle.load(metadata_file)

In [7]:
# Build the four voice-conversion networks and restore their trained weights.
# Every network is moved to `device` and every checkpoint is mapped onto it,
# so changing `device` in the configuration cell affects all four models
# (AgainVC previously hard-coded "cuda:0" and ignored that setting).
# NOTE(review): the networks are not switched to eval mode here; confirm that
# get_trans_mel does so if any of them contain dropout / batch-norm layers.
autovc = AutoVC(32, 256, 512, 16).to(device)
autovc.load_state_dict(torch.load("model/autovc_128.pt", map_location=device))

asganvc = ASGANVC(32, 256, 512, 16).to(device)
asganvc.load_state_dict(torch.load("model/asganvc_128.pt", map_location=device))

vqvc = VQVC(80, 64, 64).to(device)
vqvc.load_state_dict(torch.load("model/vqvc+.pt", map_location=device))

againvc = AgainVC(**build_config['model']['params']).to(device)
# Kept as the last expression so the cell still displays the load result.
againvc.load_state_dict(torch.load("model/againvc.pt", map_location=device))

<All keys matched successfully>

In [8]:
def _mel_to_trimmed_wave(mel):
    """Run `convert` on a single mel array and trim the result with librosa (top_db=20)."""
    batch = torch.from_numpy(mel).unsqueeze(0)
    trimmed, _ = librosa.effects.trim(convert(batch), top_db=20)
    return trimmed


# For each speaker, pick one random "source" and one random "target" utterance
# and write both to disk as reference audio.
for sp in speakers:
    # The draw order (source index, then target index) fixes how the random
    # stream is consumed, so it is kept exactly as is.
    sound_id = random.randint(3, 7)  # 7
    t_id = random.randint(8, 10)

    # Stored paths may use Windows separators; normalise them to "/".
    source_path = metadata[sp][sound_id].replace("\\", "/")
    target_path = metadata[sp][t_id].replace("\\", "/")

    mel_source = np.load(f"{ROOT}/{source_path}")
    mel_target = np.load(f"{ROOT}/{target_path}")

    source_wave = _mel_to_trimmed_wave(mel_source)
    target_wave = _mel_to_trimmed_wave(mel_target)

    save_wave(f'{SAVE_DIR}/target/target_{sp}.wav', target_wave, sample_rate)
    save_wave(f'{SAVE_DIR}/source/source_{sp}.wav', source_wave, sample_rate)

In [9]:
# Create one output folder per (model, source speaker) pair:
#   <SAVE_DIR>/<model name>/<source speaker>/
# The conversion loop writes its .wav files into these folders.
for sp_s in speakers:
    for md in model:
        # Use the loop variable `sp_s`; the former `sp` was a leftover from an
        # earlier cell, so only a single speaker folder was ever created.
        # exist_ok=True makes re-running the cell harmless while still
        # surfacing real failures (e.g. permission errors) instead of hiding
        # them behind a bare `except`.
        os.makedirs(f'{SAVE_DIR}/{md}/{sp_s}', exist_ok=True)

In [10]:
# get_trans_mel settings per model name: (network, isVQ, isAgainVC).
# isAdain is False for every model.
conversion_setups = {
    'autovc': (autovc, False, False),
    'asganvc': (asganvc, False, False),
    'vqvc': (vqvc, True, False),
}
# Any name not listed above is handled with the AgainVC settings.
fallback_setup = (againvc, False, True)

# Convert one random utterance of every source speaker towards every other
# speaker with each model, writing <SAVE_DIR>/<model>/<source>/<target>.wav.
for sp_s in speakers:
    sound_id = random.randint(3, 13)
    source_path = metadata[sp_s][sound_id].replace("\\", "/")
    mel_source = np.load(f"{ROOT}/{source_path}")
    # Entry 1 of a speaker's metadata is the array used as that speaker's embedding.
    emb_org = torch.from_numpy(metadata[sp_s][1]).unsqueeze(0).to(device)

    # Also store the (trimmed) source utterance itself for reference.
    source_batch = torch.from_numpy(mel_source).unsqueeze(0)
    source_wave = librosa.effects.trim(convert(source_batch), top_db=20)[0]
    save_wave(f'{SAVE_DIR}/source/source_{sp_s}.wav', source_wave, sample_rate)

    for sp_t in speakers:
        # A speaker is never converted onto itself.
        if sp_t == sp_s:
            continue

        target_id = random.randint(3, 13)
        target_path = metadata[sp_t][target_id].replace("\\", "/")
        mel_target = np.load(f"{ROOT}/{target_path}")
        emb_trg = torch.from_numpy(metadata[sp_t][1]).unsqueeze(0).to(device)

        for md in model:
            network, use_vq, use_again = conversion_setups.get(md, fallback_setup)
            wave = get_trans_mel(
                network,
                mel_source,
                mel_target,
                emb_org,
                emb_trg,
                isAdain=False,
                isVQ=use_vq,
                isAgainVC=use_again,
            )
            save_wave(f'{SAVE_DIR}/{md}/{sp_s}/{sp_t}.wav', wave, sample_rate)