In [1]:
import os

# Configuration: adjust these constants for your setup.
# ROOT can also be overridden without editing the notebook by exporting NEMO_ROOT.
ROOT = os.environ.get('NEMO_ROOT', '/home/vincliu/vocoder/NeMo')
os.chdir(ROOT)  # every relative path below is resolved against ROOT

CHECKPOINT = os.path.join('nemo_experiments', 'SqueezeWave', '2020-08-30_12-13-34', 'checkpoints', 'SqueezeWave--last.ckpt')
CONFIG = os.path.join('examples', 'tts', 'conf', 'squeezewave.yaml')
DATA_MANIFEST = os.path.join('data', 'nvidia_ljspeech_test.json')
OUT_FOLDER = 'samples'
# exist_ok=True replaces the separate isdir() check, so there is no gap between test and creation
os.makedirs(OUT_FOLDER, exist_ok=True)

N_SAMPLES = 10              # number of samples to run validation/inference on
MAX_WAV_VALUE = 32768.      # max wav value for scaling audio for writing to disk
DENOISER_STRENGTH = 0.01    # strength of denoiser to remove model bias
SIGMA = 0.6                 # std dev of gaussian distribution to sample from
NORMALIZE = True            # whether to normalize generated audio
SAMPLE_RATE = 22050         # sampling rate of all LJ-Speech audio files

In [2]:
import torch
from omegaconf import OmegaConf

from nemo.collections.tts.modules.squeezewave import OperationMode
from nemo.collections.tts.models.squeezewave import SqueezeWaveModel

# Remove all weight norm reparameterizations
def remove_weightnorm(model):
    """Strip the weight-norm reparameterization from the model's WaveNet blocks.

    For every block in `model.squeezewave.wavenet` this handles the `start` and
    `cond_layer` layers and the first `n_layers` entries of `res_skip_layers`.

    Args:
        model: object exposing `squeezewave.wavenet`, an indexable collection of blocks.

    Returns:
        The same `model` object; its layers are modified in place.
    """
    flow = model.squeezewave
    assert hasattr(flow, 'wavenet')
    strip = torch.nn.utils.remove_weight_norm
    for block in flow.wavenet:
        block.start = strip(block.start)
        block.cond_layer = strip(block.cond_layer)
        for layer_idx in range(block.n_layers):
            # remove_weight_norm edits the layer in place, so the return value is not needed
            strip(block.res_skip_layers[layer_idx])
    return model

# Load model with its config
def load_squeezewave_model(cfg):
    """Build a SqueezeWaveModel from a YAML config file.

    The `train_ds`, `validation_ds` and `optim` entries are deleted from the
    config's `model` section before the model is constructed (presumably so no
    training data or optimizer is set up for inference; confirm against the
    model class).

    Args:
        cfg: path to the config file.

    Returns:
        The newly constructed SqueezeWaveModel (weights are not loaded here).
    """
    with open(cfg) as cfg_file:
        full_config = OmegaConf.load(cfg_file)
    for section in ('train_ds', 'validation_ds', 'optim'):
        del full_config.model[section]
    squeezewave_model = SqueezeWaveModel(cfg=full_config['model'])
    print('Number of trainable parameters:', squeezewave_model.num_weights)
    return squeezewave_model

# Load weights and prepare model for inference
model = load_squeezewave_model(CONFIG)
# map_location='cpu' lets a checkpoint saved from a GPU load on a machine without CUDA;
# load_state_dict then copies the tensors into the model's own parameters.
checkpoint = torch.load(CHECKPOINT, map_location='cpu')['state_dict']
model.load_state_dict(checkpoint, strict=True)
model = remove_weightnorm(model)
model = model.eval()
# Both the wrapper model and the inner flow carry their own mode flag
model.mode = OperationMode.validation
model.squeezewave.mode = OperationMode.validation

[NeMo W 2020-08-31 13:52:07 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.AudioToCharDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-31 13:52:07 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.AudioToBPEDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-31 13:52:07 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.AudioLabelDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-31 13:52:07 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text._TarredAudioToTextDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-31 13:52:07 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset'> is experimental,

In [3]:
from torch.utils.data import DataLoader

from nemo.collections.tts.losses.waveglowloss import WaveGlowLoss
from nemo.collections.tts.data.datalayers import AudioDataset

# Evaluation data: one item per batch, in dataset order (no shuffling, nothing dropped)
dataset = AudioDataset(DATA_MANIFEST, n_segments=-1, truncate_to=4096)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
)

# Loss evaluated on the model outputs
loss_fn = WaveGlowLoss()

[NeMo I 2020-08-31 13:52:12 collections:171] Dataset loaded with 500 files totalling 0.91 hours
[NeMo I 2020-08-31 13:52:12 collections:172] 0 files were filtered totalling 0.00 hours


In [4]:
from nemo.collections.asr.parts.features import STFTExactPad

class Denoiser(torch.nn.Module):
    """Removes the vocoder's bias from generated audio by spectral subtraction.

    At construction time the model is run on an all-zero spectrogram with
    `sigma=0.0`; the first STFT frame of that output is kept as the bias
    spectrum. `forward` subtracts a scaled copy of it from the input's spectrum.

    Args:
        model: vocoder exposing `device`, `convert_spectrogram_to_audio`, `mode`
            and `squeezewave.mode`.
        n_mel: number of rows of the all-zero spectrogram fed to the model.
        filter_length, hop_length, win_length, window: forwarded to STFTExactPad.
    """

    def __init__(self, model, n_mel=80, filter_length=1024, hop_length=512, win_length=1024, window='hann'):
        super().__init__()

        self.stft = STFTExactPad(
            filter_length=filter_length, hop_length=hop_length, win_length=win_length, window=window,
        ).to(model.device)

        with torch.no_grad():
            mel_input = torch.zeros((n_mel, 88)).to(model.device)
            bias_audio = model.convert_spectrogram_to_audio(mel_input, sigma=0.0)
            bias_spec, _ = self.stft.transform(bias_audio.unsqueeze(0))
            # Keep only the first frame (index 0 on the last axis), retained as a size-1 axis.
            # Registered as a buffer so that it follows the module through .to()/.cuda()
            # instead of staying on the device it was created on.
            self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

        # Reset mode to validation since `model.convert_spectrogram_to_audio` sets it to infer
        model.mode = OperationMode.validation
        model.squeezewave.mode = OperationMode.validation

    def forward(self, audio, strength=0.1):
        """Return `audio` with `strength * bias_spec` subtracted from its spectrum.

        Args:
            audio: tensor accepted by `self.stft.transform`.
            strength: multiplier applied to the bias spectrum before subtraction.

        Returns:
            Audio rebuilt by `self.stft.inverse` from the reduced spectrum and the
            second value returned by `transform` (assumed to be the phase).
        """
        audio_spec, audio_angles = self.stft.transform(audio)
        audio_spec_denoised = audio_spec - self.bias_spec * strength
        # Subtraction can push values below zero; floor them at 0
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
        return audio_denoised

# Build the denoiser only when denoising is enabled; the loop below uses the same guard
if DENOISER_STRENGTH > 0:
    denoiser = Denoiser(
        model,
        n_mel=80,
        filter_length=1024,
        hop_length=512,
        win_length=1024,
        window='hann',
    )

In [5]:
import soundfile as sf
from pystoi import stoi
from maracas.maracas import asl_meter   # requires numba==0.44 (librosa usually installs numba==0.48)

from nemo.collections.asr.parts.features import FilterbankFeatures

def _to_int16(wav):
    """Scale float audio by MAX_WAV_VALUE and convert to int16, saturating at the int16 limits.

    Without the clip, a scaled sample at or beyond full scale would wrap around
    when cast (e.g. 32768 becomes -32768), producing loud clicks in the file.
    """
    scaled = wav * MAX_WAV_VALUE
    return scaled.clip(-MAX_WAV_VALUE, MAX_WAV_VALUE - 1).astype('int16')


for i, (audio, audio_len) in enumerate(dataloader):
    if i == N_SAMPLES:
        break

    # Get loss and audio reconstruction
    with torch.no_grad():
        spect, spect_len = model.audio_to_melspec_precessor(audio, audio_len)
        z, log_s_list, log_det_W_list, audio_pred, *_ = model.squeezewave(
            spect=spect, audio=audio, run_inverse=True, sigma=SIGMA,
        )
        loss = loss_fn(
            z=z, log_s_list=log_s_list, log_det_W_list=log_det_W_list, sigma=model.sigma,
        )

    # Remove model bias
    if DENOISER_STRENGTH > 0:
        audio_pred = denoiser(audio_pred, strength=DENOISER_STRENGTH)

    # Drop the size-1 axes and move to numpy for the metric and for writing to disk
    audio_pred = audio_pred.detach().cpu().numpy().squeeze()
    audio = audio.detach().cpu().numpy().squeeze()

    # Normalize active speech levels (asl)
    if NORMALIZE:
        asl_level = -26
        # Divide out the measured level, then apply the target level (both in dB)
        audio_pred = audio_pred / 10 ** (asl_meter(audio_pred, SAMPLE_RATE) / 20) * 10 ** (asl_level / 20)

    # Compute stoi (short-time objective intelligibility)
    stoi_score = stoi(audio, audio_pred, SAMPLE_RATE, extended=False)

    # Save original and synthesized audio to disk
    filename = 'audio_{}.wav'.format(i)
    audio_pred = _to_int16(audio_pred)
    sf.write(os.path.join(OUT_FOLDER, 'synth_' + filename), audio_pred, samplerate=SAMPLE_RATE)
    audio = _to_int16(audio)
    sf.write(os.path.join(OUT_FOLDER, 'orig_' + filename), audio, samplerate=SAMPLE_RATE)

    print('[{}] loss: {:.5f} stoi: {:.5f}'.format(filename, loss.item(), stoi_score.item()))

[audio_0.wav] loss: -5.84719 stoi: 0.96552
[audio_1.wav] loss: -6.03890 stoi: 0.97358
[audio_2.wav] loss: -6.17264 stoi: 0.96632
[audio_3.wav] loss: -6.43534 stoi: 0.97494
[audio_4.wav] loss: -6.10541 stoi: 0.97410
[audio_5.wav] loss: -5.85458 stoi: 0.97399
[audio_6.wav] loss: -5.99239 stoi: 0.97245
[audio_7.wav] loss: -5.97652 stoi: 0.97298
[audio_8.wav] loss: -5.86361 stoi: 0.96969
[audio_9.wav] loss: -6.22685 stoi: 0.97078
[audio_10.wav] loss: -6.29562 stoi: 0.97012
[audio_11.wav] loss: -5.93088 stoi: 0.96683
[audio_12.wav] loss: -6.02697 stoi: 0.96641
[audio_13.wav] loss: -5.92783 stoi: 0.97089
[audio_14.wav] loss: -6.05657 stoi: 0.96787
[audio_15.wav] loss: -5.73989 stoi: 0.96716
[audio_16.wav] loss: -6.23188 stoi: 0.96742
[audio_17.wav] loss: -6.18484 stoi: 0.96912
[audio_18.wav] loss: -6.27311 stoi: 0.96028
[audio_19.wav] loss: -5.97458 stoi: 0.97047
[audio_20.wav] loss: -6.16458 stoi: 0.97414
[audio_21.wav] loss: -5.91340 stoi: 0.97391
[audio_22.wav] loss: -5.83040 stoi: 0.9761