In [None]:
import torch
from modules.FastDiff.module.FastDiff_model import FastDiff
from utils import audio
from modules.FastDiff.module.util import compute_hyperparams_given_schedule, sampling_given_noise_schedule
import numpy as np

HOP_SIZE = 256  # audio samples per mel frame (value for 22050 Hz audio)

# Download the checkpoint to this path before running.
CHECKPOINT_PATH = "./checkpoints/FastDiff_tacotron/model_ckpt_steps_500000.ckpt"

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location="cpu" lets a checkpoint saved on any device load here; the model is moved to CUDA below.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")["state_dict"]["model"]

model = FastDiff().cuda()
model.load_state_dict(state_dict)
model.eval()  # inference only: switch any training-mode layers to eval behaviour

# Training-time noise schedule. The diffusion hyperparameters derived from it are passed
# to the sampler together with the much shorter inference schedule defined below.
train_noise_schedule = torch.linspace(1e-06, 0.01, 1000)
diffusion_hyperparams = compute_hyperparams_given_schedule(train_noise_schedule)

# Move the tensor-valued diffusion hyperparameters to the GPU (other entries are left as-is).
for key in ("beta", "alpha", "sigma"):
    if key in diffusion_hyperparams:
        diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()

# Inference-time noise schedules, keyed by the number of reverse sampling steps.
INFERENCE_NOISE_SCHEDULES = {
    4: [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01],
    8: [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
        0.002387222135439515, 0.035597629845142365, 0.3681158423423767,
        0.4735414385795593, 0.5],
}
N_SAMPLING_STEPS = 4  # choose 4 or 8; fewer steps = faster, more steps = potentially higher quality
noise_schedule = torch.tensor(INFERENCE_NOISE_SCHEDULES[N_SAMPLING_STEPS], dtype=torch.float32).cuda()


In [None]:
# Text-to-mel: NVIDIA's pretrained Tacotron 2 (loaded in fp16) produces the mel spectrogram
# that FastDiff will vocode in the next cell.
tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')
tacotron2 = tacotron2.to("cuda").eval()

# Bound as `tts_utils` (not `utils`) so it does not shadow the project's `utils` package
# that `audio` is imported from.
tts_utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')

text = "Welcome to a conditional diffusion probabilistic model capable of generating high fidelity speech efficiently."
sequences, lengths = tts_utils.prepare_input_sequence([text])

with torch.no_grad():
    # infer returns three values; only the mels are needed here.
    mels, _, _ = tacotron2.infer(sequences, lengths)

In [None]:
# Mel-to-waveform: run FastDiff's reverse diffusion conditioned on the Tacotron 2 mels.
SAMPLE_RATE = 22050  # must match the sample rate HOP_SIZE was chosen for
OUTPUT_PATH = './test.wav'

audio_length = mels.shape[-1] * HOP_SIZE  # HOP_SIZE output samples per mel frame
with torch.no_grad():
    pred_wav = sampling_given_noise_schedule(
        model, (1, 1, audio_length), diffusion_hyperparams, noise_schedule,
        # Tacotron 2 runs in fp16, so mels may be half precision; cast to match
        # FastDiff's default fp32 weights (no-op if already fp32).
        condition=mels.float(), ddim=False, return_sequence=False)

# Peak-normalize to [-1, 1]; the clamp avoids a division by zero on an all-silent output.
pred_wav = pred_wav / pred_wav.abs().max().clamp_min(1e-8)
audio.save_wav(pred_wav.reshape(-1).cpu().float().numpy(), OUTPUT_PATH, SAMPLE_RATE)