In [1]:
import os
import torch
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
import soundfile as sf

from nemo.collections.tts.models.squeezewave import SqueezeWaveModel
from nemo.collections.tts.losses.waveglowloss import WaveGlowLoss
from nemo.collections.tts.data.datalayers import AudioDataset
from nemo.collections.asr.parts.features import FilterbankFeatures
from nemo.collections.tts.modules.squeezewave import OperationMode

# Paths and constants used throughout this notebook -- adjust for your setup.
# ROOT is the NeMo checkout. Set the NEMO_ROOT environment variable to override
# the default below without editing the notebook.
ROOT = os.environ.get('NEMO_ROOT', '/home/vincliu/vocoder/NeMo')
# Every relative path below is resolved against ROOT.
os.chdir(ROOT)

CHECKPOINT = os.path.join('inprogress', 'SqueezeWave', '2020-08-20_09-33-30', 'checkpoints', 'SqueezeWave--last.ckpt')
CONFIG = os.path.join('examples', 'tts', 'conf', 'squeezewave.yaml')
DATA_MANIFEST = os.path.join('data', 'nvidia_ljspeech_test.json')
OUT_FOLDER = 'samples'
# exist_ok makes this a no-op when the folder is already there and avoids the
# check-then-create race of a separate isdir() test.
os.makedirs(OUT_FOLDER, exist_ok=True)

# Number of dataloader batches processed by the inference loop.
N_SAMPLES = 10
# Scale factor applied to float audio before it is cast to int16 for writing.
MAX_WAV_VALUE = 32768.

[NeMo W 2020-08-20 11:25:18 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.AudioToCharDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-20 11:25:18 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.AudioToBPEDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-20 11:25:18 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.AudioLabelDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-20 11:25:18 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text._TarredAudioToTextDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2020-08-20 11:25:18 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset'> is experimental,

In [2]:
# Remove all weight norm reparameterizations
def remove_weightnorm(model):
    """Strip the weight-norm reparameterization from the model's wavenet blocks.

    For every block in ``model.squeezewave.wavenet`` this removes weight norm
    from ``start``, ``cond_layer`` and the first ``n_layers`` entries of
    ``res_skip_layers``.

    Args:
        model: object exposing ``squeezewave.wavenet`` as an indexable sequence
            of blocks.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        AssertionError: if ``model.squeezewave`` has no ``wavenet`` attribute.
    """
    flow_net = model.squeezewave
    assert hasattr(flow_net, 'wavenet')
    strip = torch.nn.utils.remove_weight_norm
    for wn_block in flow_net.wavenet:
        wn_block.start = strip(wn_block.start)
        wn_block.cond_layer = strip(wn_block.cond_layer)
        for layer_idx in range(wn_block.n_layers):
            # remove_weight_norm edits the module in place, so the layer held
            # in res_skip_layers is updated without any reassignment.
            strip(wn_block.res_skip_layers[layer_idx])
    return model

# Load model with its config
def load_squeezewave_model(cfg):
    """Build a SqueezeWaveModel from a YAML config file.

    The ``train_ds``, ``validation_ds`` and ``optim`` entries are deleted from
    the ``model`` section before the model is constructed (presumably so that
    no training data or optimizer is set up for inference -- confirm against
    SqueezeWaveModel). The model's ``num_weights`` is printed as a side effect.

    Args:
        cfg: path to the config file.

    Returns:
        The constructed SqueezeWaveModel; no checkpoint weights are loaded here.
    """
    with open(cfg) as cfg_file:
        config = OmegaConf.load(cfg_file)
    # Drop the sections that are deleted before construction, in the same order.
    for section in ('train_ds', 'validation_ds', 'optim'):
        del config.model[section]
    squeezewave_model = SqueezeWaveModel(cfg=config['model'])
    print(squeezewave_model.num_weights)
    return squeezewave_model

# Load weights and prepare model for inference
model = load_squeezewave_model(CONFIG)
# The cells shown here never move the model or its inputs off the CPU, so map
# the checkpoint tensors to CPU as well. Without map_location, a checkpoint
# saved from a GPU run cannot be deserialized on a machine without CUDA.
checkpoint = torch.load(CHECKPOINT, map_location='cpu')['state_dict']
# strict=True: every key in the checkpoint must match the model and vice versa.
model.load_state_dict(checkpoint, strict=True)
# Weight norm is removed only after loading, so the checkpoint keys still match
# the weight-normalized parameter names at load time.
model = remove_weightnorm(model)
model = model.eval()
# Put both the wrapper and the inner module into validation mode.
# NOTE(review): the exact effect of OperationMode.validation is defined in NeMo
# and is not visible from this notebook.
model.mode = OperationMode.validation
model.squeezewave.mode = OperationMode.validation

[NeMo I 2020-08-20 11:25:19 features:301] PADDING: 16
[NeMo I 2020-08-20 11:25:19 features:310] STFT using exact pad
23665184


In [3]:
# Initialize dataset
dataset = AudioDataset(DATA_MANIFEST, n_segments=-1, truncate_to=4096)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
    num_workers=4,
)

loss_fn = WaveGlowLoss()

# Sample rate written into every output wav file.
SAMPLE_RATE = 22050


def to_int16_pcm(wave):
    """Convert a float waveform tensor to a squeezed int16 numpy array.

    The samples are scaled by MAX_WAV_VALUE and clipped to the int16 range
    before the cast. Without the clip, any sample >= 1.0 (or < -1.0) scales to
    a value outside int16 and is written as a corrupted sample.
    """
    scaled = wave.detach().cpu().numpy().squeeze() * MAX_WAV_VALUE
    return scaled.clip(-MAX_WAV_VALUE, MAX_WAV_VALUE - 1).astype('int16')


for i, (audio, audio_len) in enumerate(dataloader):
    # Only the first N_SAMPLES batches are synthesized.
    if i == N_SAMPLES:
        break
    # Get loss and audio reconstruction
    with torch.no_grad():
        z, log_s_list, log_det_W_list, audio_pred, *_ = model(audio=audio, audio_len=audio_len, run_inverse=True)
        loss = loss_fn(z=z, log_s_list=log_s_list, log_det_W_list=log_det_W_list, sigma=model.sigma)

    filename = 'audio_{}.wav'.format(i)
    print('[{}] loss: {}'.format(filename, loss.item()))

    # Write the reconstruction and the reference side by side for comparison.
    audio_pred = to_int16_pcm(audio_pred)
    sf.write(os.path.join(OUT_FOLDER, 'synth_' + filename), audio_pred, samplerate=SAMPLE_RATE)

    audio = to_int16_pcm(audio)
    sf.write(os.path.join(OUT_FOLDER, 'orig_' + filename), audio, samplerate=SAMPLE_RATE)

[NeMo I 2020-08-20 11:25:25 collections:171] Dataset loaded with 500 files totalling 0.91 hours
[NeMo I 2020-08-20 11:25:25 collections:172] 0 files were filtered totalling 0.00 hours
[audio_0.wav] loss: -5.542590618133545
[audio_1.wav] loss: -5.741009712219238
[audio_2.wav] loss: -5.875692844390869
[audio_3.wav] loss: -6.169131278991699
[audio_4.wav] loss: -5.820337772369385
[audio_5.wav] loss: -5.547118186950684
[audio_6.wav] loss: -5.7106428146362305
[audio_7.wav] loss: -5.664098739624023
[audio_8.wav] loss: -5.55489444732666
[audio_9.wav] loss: -5.970193386077881
