In [1]:
import os

# User configuration: replace these values to match the local setup.
# ROOT becomes the working directory, so every relative path below is
# resolved against it.
ROOT = '/home/vincliu/vocoder/NeMo'
os.chdir(ROOT)

# Checkpoint whose 'state_dict' is loaded into the model.
CHECKPOINT = os.path.join('nemo_experiments', 'SqueezeWave', '2020-08-30_12-13-34', 'checkpoints', 'SqueezeWave--last.ckpt')
# Config file from which the model is built.
CONFIG = os.path.join('examples', 'tts', 'conf', 'squeezewave.yaml')
# Manifest handed to the audio dataset used for evaluation.
DATA_MANIFEST = os.path.join('data', 'nvidia_ljspeech_test.json')
# Folder that receives the original and synthesized wav files.
OUT_FOLDER = 'samples'
# exist_ok=True creates the folder only when it is missing, without a separate
# isdir() check that could race with another process creating the same folder.
os.makedirs(OUT_FOLDER, exist_ok=True)

N_SAMPLES = 100             # number of samples to run validation/inference on
MAX_WAV_VALUE = 32768.      # max wav value for scaling audio for writing to disk
DENOISER_STRENGTH = 0.01    # strength of denoiser to remove model bias
SIGMA = 0.6                 # std dev of gaussian distribution to sample from
NORMALIZE = True            # whether to normalize generated audio
SAMPLE_RATE = 22050         # sampling rate of all LJ-Speech audio files

In [2]:
import torch
from omegaconf import OmegaConf

from nemo.collections.tts.helpers.helpers import remove_weightnorm
from nemo.collections.tts.modules.squeezewave import OperationMode
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel

# Build the model from its config file
def load_squeezewave_model(cfg):
    """Construct a SqueezeWaveModel from the config file at path ``cfg``.

    The ``train_ds``, ``validation_ds`` and ``optim`` entries are deleted from the
    config's ``model`` section before that section is passed to the model
    constructor. The model's ``num_weights`` is printed, and the model is returned.
    """
    with open(cfg) as cfg_file:
        full_config = OmegaConf.load(cfg_file)

    # Drop the entries that are deleted before model construction.
    for unused_key in ('train_ds', 'validation_ds', 'optim'):
        del full_config.model[unused_key]

    squeezewave_model = SqueezeWaveModel(cfg=full_config['model'])
    print('Number of trainable parameters:', squeezewave_model.num_weights)
    return squeezewave_model

# Load weights and prepare model for inference
model = load_squeezewave_model(CONFIG)
# map_location='cpu' lets a checkpoint that was saved from a GPU run be
# deserialized on a machine without CUDA. load_state_dict then copies the
# values into the model's existing parameters.
# NOTE: `checkpoint` holds only the 'state_dict' entry of the checkpoint file.
checkpoint = torch.load(CHECKPOINT, map_location='cpu')['state_dict']
model.load_state_dict(checkpoint, strict=True)

# Replace the vocoder module with the result of remove_weightnorm, then switch
# the model to eval mode and set both it and the vocoder module to validation mode.
model.squeezewave = remove_weightnorm(model.squeezewave)
model = model.eval()
model.mode = OperationMode.validation
model.squeezewave.mode = OperationMode.validation

[NeMo W 2020-08-31 14:17:40 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.AudioToCharDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-31 14:17:40 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.AudioToBPEDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-31 14:17:40 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.AudioLabelDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-31 14:17:40 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text._TarredAudioToTextDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-31 14:17:40 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset'> is experimental,

In [3]:
from torch.utils.data import DataLoader

from nemo.collections.tts.losses.waveglowloss import WaveGlowLoss
from nemo.collections.tts.data.datalayers import AudioDataset
from nemo.collections.tts.modules.denoiser import SqueezeWaveDenoiser

# Unshuffled loader over the manifest, one item per batch
dataloader = DataLoader(
    AudioDataset(DATA_MANIFEST, n_segments=-1, truncate_to=4096),
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
)

# Loss used to score each sample in the evaluation loop
loss_fn = WaveGlowLoss()

# The denoiser is only built (and `denoiser` only defined) when denoising is enabled
if DENOISER_STRENGTH > 0:
    denoiser = SqueezeWaveDenoiser(
        model,
        n_mel=80,
        filter_length=1024,
        hop_length=512,
        win_length=1024,
        window='hann',
    )

[NeMo I 2020-08-31 14:17:46 collections:171] Dataset loaded with 500 files totalling 0.91 hours
[NeMo I 2020-08-31 14:17:46 collections:172] 0 files were filtered totalling 0.00 hours


In [4]:
import soundfile as sf
from pystoi import stoi
from maracas.maracas import asl_meter   # requires numba==0.44 for autojit (librosa usually installs numba==0.48)

from nemo.collections.asr.parts.features import FilterbankFeatures

def _to_int16(wave):
    """Scale a float waveform by MAX_WAV_VALUE and convert it to int16.

    The scaled samples are clipped to [-MAX_WAV_VALUE, MAX_WAV_VALUE - 1] before
    the cast, so out-of-range samples (e.g. after level normalization, or a sample
    equal to 1.0) saturate instead of overflowing the int16 range.
    """
    scaled = wave * MAX_WAV_VALUE
    return scaled.clip(-MAX_WAV_VALUE, MAX_WAV_VALUE - 1).astype('int16')


# Evaluate up to N_SAMPLES items: compute the loss and STOI for each one and
# write the original and synthesized audio to OUT_FOLDER.
for i, (audio, audio_len) in enumerate(dataloader):
    if i == N_SAMPLES:
        break

    # Get loss and audio reconstruction
    with torch.no_grad():
        spect, spect_len = model.audio_to_melspec_precessor(audio, audio_len)
        z, log_s_list, log_det_W_list, audio_pred, *_ = model.squeezewave(
            spect=spect, audio=audio, run_inverse=True, sigma=SIGMA,
        )
        # NOTE(review): the loss uses model.sigma while synthesis above uses SIGMA
        loss = loss_fn(
            z=z, log_s_list=log_s_list, log_det_W_list=log_det_W_list, sigma=model.sigma,
        )

    # Remove model bias
    if DENOISER_STRENGTH > 0:
        audio_pred = denoiser(audio_pred, strength=DENOISER_STRENGTH)

    # Convert both signals to numpy arrays with singleton dimensions removed
    audio_pred = audio_pred.detach().cpu().numpy().squeeze()
    audio = audio.detach().cpu().numpy().squeeze()

    # Normalize active speech levels (asl)
    if NORMALIZE:
        asl_level = -26
        audio_pred = audio_pred / 10 ** (asl_meter(audio_pred, SAMPLE_RATE) / 20) * 10 ** (asl_level / 20)

    # Compute stoi (short-time objective intelligibility)
    stoi_score = stoi(audio, audio_pred, SAMPLE_RATE, extended=False)

    # Save original and synthesized audio to disk (clipped to the int16 range)
    filename = 'audio_{}.wav'.format(i)
    audio_pred = _to_int16(audio_pred)
    sf.write(os.path.join(OUT_FOLDER, 'synth_' + filename), audio_pred, samplerate=SAMPLE_RATE)
    audio = _to_int16(audio)
    sf.write(os.path.join(OUT_FOLDER, 'orig_' + filename), audio, samplerate=SAMPLE_RATE)

    print('[{}] loss: {:.5f} stoi: {:.5f}'.format(filename, loss.item(), stoi_score.item()))

[audio_0.wav] loss: -5.88563 stoi: 0.96898
[audio_1.wav] loss: -6.07639 stoi: 0.97734
[audio_2.wav] loss: -6.20760 stoi: 0.97059
[audio_3.wav] loss: -6.47904 stoi: 0.97270
[audio_4.wav] loss: -6.14631 stoi: 0.97634
[audio_5.wav] loss: -5.89166 stoi: 0.97234
[audio_6.wav] loss: -6.03365 stoi: 0.97227
[audio_7.wav] loss: -6.02054 stoi: 0.97181
[audio_8.wav] loss: -5.90490 stoi: 0.96592
[audio_9.wav] loss: -6.27257 stoi: 0.96852
[audio_10.wav] loss: -6.33750 stoi: 0.96953
[audio_11.wav] loss: -5.98078 stoi: 0.96732
[audio_12.wav] loss: -6.07245 stoi: 0.96794
[audio_13.wav] loss: -5.96641 stoi: 0.96922
[audio_14.wav] loss: -6.10163 stoi: 0.97090
[audio_15.wav] loss: -5.77293 stoi: 0.96386
[audio_16.wav] loss: -6.27294 stoi: 0.96726
[audio_17.wav] loss: -6.22984 stoi: 0.97070
[audio_18.wav] loss: -6.31492 stoi: 0.96074
[audio_19.wav] loss: -6.01833 stoi: 0.97529
[audio_20.wav] loss: -6.20214 stoi: 0.97651
[audio_21.wav] loss: -5.95097 stoi: 0.97590
[audio_22.wav] loss: -5.87310 stoi: 0.9783