In [1]:
import os
import librosa
import time
from pathlib import Path

import torch

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import expit

from tacotron2.models.tacotron2 import Tacotron2Jit
from tacotron2.hparams import HParams
from tacotron2.tokenizers import RussianPhonemeTokenizer
from tacotron2.evaluators import BaseEvaluator, EmbeddingEvaluator
from tacotron2.evaluators import get_evaluator, plot_syntesis_result

from IPython.display import Audio

%matplotlib inline

In [2]:
def benchmark(model, inputs, n=30, **kwargs):
    with torch.jit.optimized_execution(True):
        times = []
        for _ in range(n):
            ts = time.time()
            outputs = model(inputs, **kwargs)
            te = time.time()
            times.append(te - ts)
    print(f"Mean: {np.mean(times)}")
    print(f"Std: {np.std(times)}")

### Tacotron2

---

In [3]:
tokenizer = RussianPhonemeTokenizer()

In [4]:
DEVICE = 'cuda:0'

In [5]:
text = 'Привет, коллеги! Сегодня мы с вами будем учить дискретное преобразование Фурье.'
inputs = torch.LongTensor(tokenizer.encode(text)).unsqueeze(0).to(DEVICE)



---

In [6]:
hparams_tacotron = HParams.from_yaml(
    'output/melnik2/hparams.yaml'
)
#hparams_tacotron.sample_embedding_dim = 256
hparams_tacotron.n_symbols = 152
hparams_tacotron.gate_threshold = 0.5
hparams_tacotron.max_decoder_steps = 1600

In [7]:
model = Tacotron2Jit(hparams_tacotron).to(DEVICE)
weights = torch.load('output/melnik2/models/model_best.pth', map_location=DEVICE)
model.load_state_dict(weights['model_state_dict'])
model.eval();

In [8]:
benchmark(model, inputs, 10)

Mean: 0.6111136913299561
Std: 0.04149939475990331


---

In [9]:
model_scripted = torch.jit.script(model)

In [10]:
benchmark(model_scripted, inputs, 10)

Mean: 0.5428481101989746
Std: 0.16673798155911645


In [11]:
mel_outputs, mel_outputs_postnet, gates, alignments = model_scripted(inputs)

In [12]:
# del model
# del model_scripted
# torch.cuda.empty_cache()

---

### Waveglow

In [13]:
from waveglow.models import WaveGlow

---

In [14]:
hparams_wg = HParams.from_yaml('../waveglow/configs/hparams.default.yaml')
hparams_wg.hop_length = 256
hparams_wg.win_length = 1024

waveglow = WaveGlow(hparams_wg).to(DEVICE)

In [15]:
vocoder_checkpoint_path = '../waveglow/models/pretrained_waweglow.pt'
vocoder_loaded_weights = torch.load(vocoder_checkpoint_path, map_location=DEVICE)
if 'model_state_dict' in vocoder_loaded_weights:
    waveglow.load_state_dict(
        torch.load(vocoder_checkpoint_path, map_location=DEVICE)['model_state_dict']
    )
else:
    waveglow.load_state_dict(
        torch.load(vocoder_checkpoint_path, map_location=DEVICE)
    )
    
waveglow.eval()

In [16]:
benchmark(waveglow.infer, mel_outputs_postnet, 50, sigma=torch.Tensor([1]))

Mean: 0.3274675989151001
Std: 0.02406716399762775


In [17]:
signal = waveglow.infer(mel_outputs_postnet, 1.)

In [18]:
Audio(data=signal.cpu().numpy(), rate=22050)

---

In [27]:
scripted_waveglow = torch.jit.script(waveglow)

In [28]:
benchmark(scripted_waveglow.infer, mel_outputs_postnet, 50, sigma=torch.Tensor([1]))

Mean: 0.3244642639160156
Std: 0.009629699187366591


In [23]:
signal_scripted = scripted_waveglow.infer(mel_outputs_postnet, torch.Tensor([1.]))

In [24]:
Audio(data=signal_scripted.cpu().numpy(), rate=22050)

In [29]:
(54 + 32) / (61 + 32)

0.9247311827956989