/
synthesize.py
91 lines (79 loc) · 3.71 KB
/
synthesize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from config import ConfigArgs as args
import os, sys
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
import numpy as np
import pandas as pd
from model import Text2Mel, SSRN
from data import TextDataset, synth_collate_fn, load_vocab
import utils
from scipy.io.wavfile import write
def synthesize(t2m, ssrn, data_loader, batch_size=100):
    '''
    DCTTS Architecture
    Text --> Text2Mel --> SSRN --> Wav file

    For every batch of ``data_loader`` this runs Text2Mel step by step, plots
    the attention alignment of each utterance into ``<args.sampledir>/A``,
    feeds the predicted mels to SSRN, and finally writes one wav file per
    utterance (``1.wav``, ``2.wav``, ...) into ``args.sampledir``.

    :param t2m: Text2Mel model, called as ``t2m(texts, prev_mel_hats)`` and
        expected to return ``(mel_hats, A)``.
    :param ssrn: SSRN model, called on a batch of predicted mels.
    :param data_loader: loader yielding 3-tuples whose first item is the text
        batch; ``data_loader.dataset`` must support ``len()``.
    :param batch_size: kept for backward compatibility only. Row offsets are
        derived from the actual length of each batch, so this value is no
        longer used.
    :returns: None. Relies on the module-level ``DEVICE`` being set.
    '''
    idx2char = load_vocab()[-1]
    align_dir = os.path.join(args.sampledir, 'A')
    n_samples = len(data_loader.dataset)
    with torch.no_grad():
        print('='*10, ' Text2Mel ', '='*10)
        mags = torch.zeros([n_samples, args.max_Ty*args.r, args.n_mags]).to(DEVICE)
        # Index of the first utterance of the current batch within the dataset.
        offset = 0
        for texts, _, _ in data_loader:
            texts = texts.to(DEVICE)
            n_batch = len(texts)
            prev_mel_hats = torch.zeros([n_batch, args.max_Ty, args.n_mels]).to(DEVICE)
            # Autoregressive decoding: frame t of the prediction becomes
            # frame t+1 of the next step's input.
            for t in tqdm(range(args.max_Ty-1), unit='B', ncols=70):
                mel_hats, A = t2m(texts, prev_mel_hats)  # mel: (N, Ty/r, n_mels)
                prev_mel_hats[:, t+1, :] = mel_hats[:, t, :]

            print('='*10, ' Alignment ', '='*10)
            alignments = A.cpu().detach().numpy()
            visual_texts = texts.cpu().detach().numpy()
            for idx, alignment in enumerate(alignments):
                text = [idx2char[ch] for ch in visual_texts[idx]]
                # Name plots by dataset-wide index so later batches do not
                # overwrite the plots of earlier ones.
                utils.plot_att(alignment, text, args.global_step, path=align_dir,
                               name='{}.png'.format(offset + idx))

            print('='*10, ' SSRN ', '='*10)
            # Mel --> Mag
            mags[offset:offset + n_batch, :, :] = ssrn(prev_mel_hats)  # mag: (N, Ty, n_mags)
            offset += n_batch

        # Convert once, after all batches are filled; converting inside the
        # loop would break the tensor slice assignment of the next batch.
        mags = mags.cpu().detach().numpy()
        print('='*10, ' Vocoder ', '='*10)
        for idx in trange(len(mags), unit='B', ncols=70):
            wav = utils.spectrogram2wav(mags[idx])
            write(os.path.join(args.sampledir, '{}.wav'.format(idx+1)), args.sr, wav)
    return None
def _load_best_checkpoint(model):
    '''
    Load the lowest-loss checkpoint of ``model`` and return the loaded state.

    Reads ``<args.logdir>/<model.name>/ckpt.csv`` (headerless, two columns:
    checkpoint file name and loss), picks the row with the smallest loss and
    loads its ``'model'`` entry into ``model``.

    :param model: module exposing a ``name`` attribute and ``load_state_dict``.
    :returns: the full checkpoint object returned by ``torch.load``.
    '''
    model_dir = os.path.join(args.logdir, model.name)
    ckpt = pd.read_csv(os.path.join(model_dir, 'ckpt.csv'), sep=',', header=None)
    ckpt.columns = ['models', 'loss']
    ckpt = ckpt.sort_values(by='loss', ascending=True)
    # Positional access: after sorting, .loc[0] would still return the row
    # labelled 0 (the first CSV line), not the row with the smallest loss.
    best_name = ckpt.models.iloc[0]
    # map_location lets checkpoints saved on a GPU load on the CPU fallback.
    state = torch.load(os.path.join(model_dir, best_name), map_location=DEVICE)
    model.load_state_dict(state['model'])
    return state


def main():
    '''
    Build the test loader, restore the best Text2Mel and SSRN checkpoints and
    synthesize every utterance of ``args.testset`` into ``args.sampledir``.

    Relies on the module-level ``DEVICE`` being set before it is called.
    '''
    testset = TextDataset(args.testset)
    test_loader = DataLoader(dataset=testset, batch_size=args.test_batch, drop_last=False,
                             shuffle=False, collate_fn=synth_collate_fn, pin_memory=True)

    t2m = Text2Mel().to(DEVICE)
    ssrn = SSRN().to(DEVICE)

    # The Text2Mel checkpoint's global step is the one passed to the
    # alignment plots in synthesize().
    args.global_step = _load_best_checkpoint(t2m)['global_step']
    _load_best_checkpoint(ssrn)
    print('All of models are loaded.')

    t2m.eval()
    ssrn.eval()

    # Directory that receives the alignment plots.
    os.makedirs(os.path.join(args.sampledir, 'A'), exist_ok=True)
    synthesize(t2m, ssrn, test_loader, args.test_batch)
if __name__ == '__main__':
    # Usage: python synthesize.py <gpu_id>
    # Exit with a usage message instead of a bare traceback when the GPU id
    # is missing or is not an integer.
    try:
        gpu_id = int(sys.argv[1])
    except (IndexError, ValueError):
        sys.exit('usage: python synthesize.py <gpu_id>')
    # Expose only the requested GPU to this process; this is done before
    # torch is asked whether CUDA is available.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_id)
    # Module-level global read by main() and synthesize(); falls back to CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    main()