In [None]:
import datetime as dt
from pathlib import Path

import IPython.display as ipd
import numpy as np
import soundfile as sf
import torch
import pickle

# Hifigan imports
from prosodyfm.hifigan.config import v1
from prosodyfm.hifigan.denoiser import Denoiser
from prosodyfm.hifigan.env import AttrDict
from prosodyfm.hifigan.models import Generator as HiFiGAN

from prosodyfm.models.prosodyfm import ProsodyFM
from prosodyfm.text import sequence_to_text, text_to_sequence
from prosodyfm.utils.model import denormalize
from prosodyfm.utils.utils import get_user_data_dir, intersperse

In [None]:
# Enable IPython's autoreload so edits to imported project modules (e.g. prosodyfm.*)
# are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline. No cell shown here currently plots; kept for convenience.
%matplotlib inline

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# All paths are relative to the notebook's working directory.
PROSODYFM_CHECKPOINT = "./checkpoints/checkpoint_epoch=349.ckpt"  # ProsodyFM model checkpoint
HIFIGAN_CHECKPOINT = "./hifigan/released_checkpoints/g_02500000"  # holds the HiFi-GAN 'generator' weights
OUTPUT_FOLDER = "./demo_samples/with_boundary_gst"  # synthesised wav files are written here

In [None]:
def load_model(checkpoint_path):
    """Load a ProsodyFM checkpoint and switch the model to eval mode.

    Args:
        checkpoint_path: path to a ProsodyFM ``.ckpt`` file.

    Returns:
        The loaded ``ProsodyFM`` model in eval mode. Checkpoint tensors are
        mapped to ``device`` while loading.
    """
    model = ProsodyFM.load_from_checkpoint(checkpoint_path, map_location=device)
    model.eval()
    return model


def count_params(x):
    """Return the total parameter count of module ``x`` as a comma-grouped string, e.g. ``'2,020'``."""
    # A named function instead of an assigned lambda (PEP 8 E731): it gets a real
    # __name__ in tracebacks and can carry a docstring.
    return f"{sum(p.numel() for p in x.parameters()):,}"


# Load the acoustic model once; every later cell reuses this global.
model = load_model(PROSODYFM_CHECKPOINT)
n_model_params = count_params(model)
print(f"Model loaded! Parameter count: {n_model_params}")

## Load HiFi-GAN (Vocoder)

In [None]:
def load_vocoder(checkpoint_path):
    """Build the HiFi-GAN generator (config ``v1``) and load its weights.

    The checkpoint must be a dict whose ``'generator'`` entry is the generator's
    state dict. The generator is moved to ``device``, put in eval mode, and has
    weight norm removed before it is returned.
    """
    generator = HiFiGAN(AttrDict(v1)).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()
    generator.remove_weight_norm()
    return generator

vocoder = load_vocoder(HIFIGAN_CHECKPOINT)
# Denoiser built on top of the vocoder; to_waveform() uses this global instance.
denoiser = Denoiser(vocoder, mode="zeros")

In [None]:
# Pre-processed test filelist: one positional entry per utterance (see get_testing_data).
# NOTE: pickle.load can execute arbitrary code; only load filelists from a trusted source.
TEST_FILELIST_PKL = "./libritts_audio_sid_text_test_filelist_b_pitchseg_t5_final.pkl"
with open(TEST_FILELIST_PKL, "rb") as filelist_fh:
    t5_filelist = pickle.load(filelist_fh)

In [None]:
def get_testing_data(test_filelist, index):
    """Convert one test-filelist entry into the inputs ``model.synthesise`` expects.

    Entries are read positionally:

    * ``info[0]``: reference wav file path (returned unchanged as ``'wav_file'``)
    * ``info[1]``: speaker id, anything ``int()`` accepts
    * ``info[2]``: input text, cleaned with ``english_cleaners2``
    * ``info[3]``: pitch-segment values (numpy array or sequence), cast to float32
    * ``info[4]``: pitch-segment lengths (numpy array or sequence), cast to int64
    * ``info[5]``: last-word number (scalar)

    NOTE(review): the shapes of ``info[3]``/``info[4]`` are not visible here. They
    are forwarded as-is, without adding a batch dimension.

    Args:
        test_filelist: indexable collection of entries laid out as above.
        index: position of the entry to convert.

    Returns:
        A dict with these keys:

        * ``x``: interspersed symbol ids, shape ``(1, L)``
        * ``x_orig``: the raw text
        * ``x_lengths``, ``spk``, ``x_pitch_seg``, ``x_pitch_seg_lengths``,
          ``x_last_word_num``
        * ``x_phones``: ``x`` decoded back to text
        * ``wav_file``

        All tensors are on ``device``.
    """
    info = test_filelist[index]
    wav_file, speaker_id, text = info[0], info[1], info[2]

    # Text -> symbol ids, with 0 interspersed between ids, plus a leading batch dim.
    sequence = text_to_sequence(text, ['english_cleaners2'])
    x = torch.tensor(intersperse(sequence, 0), dtype=torch.long, device=device)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    x_phones = sequence_to_text(x.squeeze(0).tolist())

    spk = torch.tensor([int(speaker_id)], device=device, dtype=torch.long)

    # torch.as_tensor accepts numpy arrays *and* plain Python sequences/scalars;
    # torch.from_numpy only accepts ndarrays. The dtypes match the previous .float()/.long().
    x_pitch_seg = torch.as_tensor(info[3], dtype=torch.float32, device=device)
    x_pitch_seg_lengths = torch.as_tensor(info[4], dtype=torch.long, device=device)
    x_last_word_num = torch.tensor([info[5]], device=device, dtype=torch.long)

    return {
        'x': x,
        'x_orig': text,
        'x_lengths': x_lengths,
        'spk': spk,
        'x_pitch_seg': x_pitch_seg,
        'x_pitch_seg_lengths': x_pitch_seg_lengths,
        'x_last_word_num': x_last_word_num,
        'x_phones': x_phones,
        'wav_file': wav_file,
    }


In [None]:
@torch.inference_mode()
def synthesise(test_filelist, index, n_timesteps=10, length_scale=1.0, temperature=0.667):
    """Run the global ProsodyFM ``model`` on one test-filelist entry.

    The defaults reproduce the previously hard-coded settings. Pass different
    values to actually change synthesis.

    Args:
        test_filelist: indexable collection of entries (layout as read by get_testing_data).
        index: position of the entry to synthesise.
        n_timesteps: number of ODE solver steps forwarded to ``model.synthesise``.
        length_scale: speaking-rate scale forwarded to ``model.synthesise``.
        temperature: sampling temperature forwarded to ``model.synthesise``.

    Returns:
        The dict returned by ``model.synthesise``, updated with:

        * ``'start_t'``: the wall-clock time taken right before the model call
        * every key produced by get_testing_data (these overwrite equally named
          model outputs)
    """
    processed_data = get_testing_data(test_filelist, index)
    # Taken after input preparation so the RTF covers only model (and later vocoder) time.
    start_t = dt.datetime.now()
    output = model.synthesise(
        x=processed_data['x'],
        x_lengths=processed_data['x_lengths'],
        n_timesteps=n_timesteps,
        spks=processed_data['spk'],
        x_pitch_seg=processed_data['x_pitch_seg'],
        x_pitch_seg_lengths=processed_data['x_pitch_seg_lengths'],
        x_last_word_num=processed_data['x_last_word_num'],
        length_scale=length_scale,
        temperature=temperature,
    )

    # Merge everything into one dict.
    output.update({'start_t': start_t, **processed_data})
    return output

@torch.inference_mode()
def to_waveform(mel, vocoder):
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
    return audio.cpu().squeeze()
    
def save_to_folder(filename: str, output: dict, folder: str, sample_rate: int = 22050):
    """Write ``output['waveform']`` to ``folder / filename`` as 24-bit PCM audio.

    ``folder`` (and any missing parents) is created if needed.

    Args:
        filename: name of the file to write inside ``folder``.
        output: dict holding the audio under ``'waveform'``.
        folder: destination directory (str or Path).
        sample_rate: sample rate in Hz stored in the file header. The default
            matches the previously hard-coded 22050.

    Returns:
        The ``Path`` of the written file.
    """
    folder = Path(folder)
    folder.mkdir(exist_ok=True, parents=True)
    out_path = folder / filename
    sf.write(out_path, output['waveform'], sample_rate, 'PCM_24')
    return out_path

## Set up the text to synthesise

### Hyperparameters

In [None]:
# Synthesis hyperparameters.
# NOTE(review): the synthesis loop calls synthesise(t5_filelist, i) without passing
# these values, so editing them here does not change the generated audio.
# n_timesteps is only echoed in the final summary.
n_timesteps = 10     # number of ODE solver steps
length_scale = 1.0   # changes to the speaking rate
temperature = 0.667  # sampling temperature

## Synthesis

In [None]:
SAMPLE_RATE = 22050  # waveform sample rate (Hz) used for playback and the RTF computation

outputs, rtfs = [], []
rtfs_w = []


def _print_report(i, output, rtf_w):
    """Pretty-print the input text, phonemes, speaker, source file and RTFs of sample ``i``."""
    star_rule, dash_rule = '*' * 53, '-' * 53
    print(star_rule)
    print(f"Input text - {i}")
    print(dash_rule)
    print(output['x_orig'])
    print(star_rule)
    print(f"Phonetised text - {i}")
    print(dash_rule)
    print(output['x_phones'])
    print(star_rule)
    print(f"Speaker Id - {output['spk']}")
    print(dash_rule)
    print(output['wav_file'])
    print(star_rule)
    print(f"RTF:\t\t{output['rtf']:.6f}")
    print(f"RTF Waveform:\t{rtf_w:.6f}")


for i in range(len(t5_filelist)):
    output = synthesise(t5_filelist, i)
    output['waveform'] = to_waveform(output['mel'], vocoder)

    # Real Time Factor including the vocoder: wall-clock seconds since just before
    # model.synthesise, divided by the seconds of audio produced.
    elapsed = (dt.datetime.now() - output['start_t']).total_seconds()
    rtf_w = elapsed * SAMPLE_RATE / output['waveform'].shape[-1]

    _print_report(i, output, rtf_w)
    rtfs.append(output['rtf'])
    rtfs_w.append(rtf_w)

    # Display the synthesised waveform.
    ipd.display(ipd.Audio(output['waveform'], rate=SAMPLE_RATE))
    # Path(...).name is separator-aware and also accepts Path objects, unlike str.split('/').
    wav_name = Path(output['wav_file']).name
    # Save the generated waveform.
    save_to_folder(wav_name, output, OUTPUT_FOLDER)

# NOTE(review): n_timesteps comes from the hyperparameter cell and is not forwarded
# to synthesise() above; it only matches the steps actually used while both are 10.
print(f"Number of ODE steps: {n_timesteps}")
if rtfs:
    print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
    print(f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}")
else:
    # np.mean([]) would emit a RuntimeWarning and print nan.
    print("No samples were synthesised; RTF statistics are unavailable.")