In [1]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from onmt_modules.misc import sequence_mask
# from model_autopst import Generator_2 as Predictor
from hparams_autopst import hparams
from utils import filter_bank_mean
from fast_decoders import DecodeFunc_Sp
from model_sea import Encoder_2 as Encoder_Code_2
from override_decoder import OnmtDecoder_1 as OnmtDecoder

c:\users\aakri\appdata\local\programs\python\python37\lib\site-packages\numpy\.libs\libopenblas.JPIJNSWNNAN3CE6LLI5FWSPHUT2VXMTH.gfortran-win_amd64.dll
c:\users\aakri\appdata\local\programs\python\python37\lib\site-packages\numpy\.libs\libopenblas.NOIJJG62EMASZI6NYURL6JBKM4EVBGM7.gfortran-win_amd64.dll
c:\users\aakri\appdata\local\programs\python\python37\lib\site-packages\numpy\.libs\libopenblas.PYQHXLVVQ7VESDPUVUADXEVJOBGHJPAY.gfortran-win_amd64.dll
  stacklevel=1)


In [2]:
from onmt_modules.misc import sequence_mask
from onmt_modules.embeddings import PositionalEncoding
from onmt_modules.encoder_transformer import TransformerEncoder as OnmtEncoder

In [3]:
class Prenet(nn.Module):
    def __init__(self, dim_input, dim_output, dropout=0.1):
        super().__init__() 
        
        mlp = nn.Linear(dim_input, dim_output, bias=True)
        pe = PositionalEncoding(dropout, dim_output, 1600)
        
        self.make_prenet = nn.Sequential()
        self.make_prenet.add_module('mlp', mlp)
        self.make_prenet.add_module('pe', pe)
        
        self.word_padding_idx = 1
        
    def forward(self, source, step=None):
        
        for i, module in enumerate(self.make_prenet._modules.values()):
            if i == len(self.make_prenet._modules.values()) - 1:
                source = module(source, step=step)
            else:
                source = module(source)
                
        return source

In [4]:
class Encoder_Tx_Spk(nn.Module):
    """
    Text Encoder
    """
    def __init__(self, hparams):
        super().__init__() 
        
        prenet = Prenet(hparams.dim_code+hparams.dim_spk, 
                        hparams.enc_rnn_size)
        self.encoder = OnmtEncoder.from_opt(hparams, prenet)
        
    def forward(self, src, src_lengths, spk_emb):
        
        spk_emb = spk_emb.unsqueeze(0).expand(src.size(0),-1,-1)
        src_spk = torch.cat((src, spk_emb), dim=-1)
        enc_states, memory_bank, src_lengths = self.encoder(src_spk, src_lengths)
        
        return enc_states, memory_bank, src_lengths

In [5]:
class Decoder_Sp(nn.Module):
    """
    Speech Decoder
    """
    def __init__(self, hparams):
        super().__init__() 
        
        self.dim_freq = hparams.dim_freq
        self.max_decoder_steps = hparams.dec_steps_sp
        self.gate_threshold = hparams.gate_threshold
        
        prenet = Prenet(hparams.dim_freq, hparams.dec_rnn_size)
        self.decoder = OnmtDecoder.from_opt(hparams, prenet)

        self.postnet = nn.Linear(hparams.dec_rnn_size, 
                                 hparams.dim_freq+1, bias=True)
        
    def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
        
        dec_outs, attns = self.decoder(tgt, memory_bank, step=None, 
                                       memory_lengths=memory_lengths,
                                       tgt_lengths=tgt_lengths)
        spect_gate = self.postnet(dec_outs)
        spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
        
        return spect, gate

In [6]:
class Predictor(nn.Module):
    '''
    async stage 2
    '''
    def __init__(self, hparams):
        super().__init__() 
        
        self.encoder_cd = Encoder_Code_2(hparams)
        self.encoder_tx = Encoder_Tx_Spk(hparams)
        self.decoder_sp = Decoder_Sp(hparams)   
        self.encoder_spk = nn.Linear(hparams.dim_spk, 
                                     hparams.enc_rnn_size, bias=True)
        self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')
        
        
    def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
                      tgt_spect, len_spect, 
                      spk_emb):
        
        cd_long = self.encoder_cd(cep_in, mask_long)
        fb = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
        
        cd_short = torch.bmm(fb.detach(), cd_long.detach())
        
        spk_emb_1 = self.encoder_spk(spk_emb)
        
        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_short.transpose(1,0), len_short, 
                                          spk_emb)
        memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
        self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
        spect_out, gate_sp_out \
        = self.decoder_sp(tgt_spect, len_spect, memory_tx_spk, len_short+1)
        
        return spect_out, gate_sp_out
    
    
    def infer_onmt(self, cep_in, mask_long, len_spect,
                   spk_emb):
        
        cd_long = self.encoder_cd(cep_in, mask_long)
        
        spk_emb_1 = self.encoder_spk(spk_emb)
        
        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_long.transpose(1,0), len_spect, 
                                          spk_emb)
        memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
        self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
        spect_output, len_spect_out, stop_sp_output \
        = self.fast_dec_sp.infer(None, memory_tx_spk, len_spect+1, 
                                 self.decoder_sp.decoder, 
                                 self.decoder_sp.postnet)
        
        return spect_output, len_spect_out

In [7]:
device = 'cuda:0'

P = Predictor(hparams).eval().to(device)

checkpoint = torch.load('./assets/580000-P.ckpt', map_location=lambda storage, loc: storage)  
P.load_state_dict(checkpoint['model'], strict=True)
print('Loaded predictor .....................................................')

dict_test = pickle.load(open('./assets/test_vctk.meta', 'rb'))

Loaded predictor .....................................................


In [8]:
print(len(dict_test['p231']))

print(dict_test['p231'])

2
OrderedDict([('001', (array([[ 0.12554671, -0.16963401, -0.58115584, ..., -0.51349175,
         1.2145282 , -0.39891607],
       [ 0.6724523 , -0.20496637, -0.6058039 , ..., -0.01747223,
         1.2504997 , -0.6045832 ],
       [ 0.48696825, -0.3501245 , -0.67875063, ...,  0.22466417,
         1.037844  , -0.31743672],
       ...,
       [ 0.0410398 ,  0.11647096, -0.502512  , ...,  0.35522646,
         1.1331956 ,  1.1371542 ],
       [-0.47664598,  0.08350065, -0.511562  , ..., -0.29939383,
         0.00328085,  0.9134802 ],
       [-1.1690527 ,  0.08582662, -0.736655  , ..., -0.31217238,
         0.54107565,  0.69253266]], dtype=float32), array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [9]:
spect_vc = OrderedDict()

uttrs = [('p231', 'p270', '001'),
         ('p270', 'p231', '001'),
         ('p231', 'p245', '003001'),
         ('p245', 'p231', '003001'),
         ('p239', 'p270', '024002'),
         ('p270', 'p239', '024002')]


for uttr in uttrs:
        
    cep_real, spk_emb = dict_test[uttr[0]][uttr[2]]
    cep_real_A = torch.from_numpy(cep_real).unsqueeze(0).to(device)
    len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)
    real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()
    
    _, spk_emb = dict_test[uttr[1]][uttr[2]]
    spk_emb_B = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
    
    with torch.no_grad():
        spect_output, len_spect = P.infer_onmt(cep_real_A.transpose(2,1)[:,:14,:],
                                               real_mask_A,
                                               len_real_A,
                                               spk_emb_B)
    
    uttr_tgt = spect_output[:len_spect[0],0,:].cpu().numpy()
        
    spect_vc[f'{uttr[0]}_{uttr[1]}_{uttr[2]}'] = uttr_tgt

In [None]:
# spectrogram to waveform
# Feel free to use other vocoders
# This cell requires some preparation to work, please see the corresponding part in AutoVC
import torch
import librosa
import pickle
import os
from synthesis import build_model
from synthesis import wavegen
import soundfile as sf

model = build_model().to(device)
checkpoint = torch.load("./assets/checkpoint_step001000000_ema.pth")
model.load_state_dict(checkpoint["state_dict"])
outs = []

for name, sp in spect_vc.items():
    print(name)
    
    waveform = wavegen(model, c=sp)  
    outs.append(waveform)
    sf.write('./assets/'+name+'.wav', waveform, 16000)
#     write('./assets/'+name+'.wav', 16000, waveform.astype(np.int16))

  0%|                                                                                | 3/18944 [00:00<14:27, 21.84it/s]

p231_p270_001


 88%|███████████████████████████████████████████████████████████████████         | 16716/18944 [12:02<01:41, 21.86it/s]