In [13]:
from argparse import ArgumentParser
from pathlib import Path

import h5py
import librosa
import matplotlib
import torch
import tqdm
import tqdm.notebook
from IPython.display import Audio, display

import sys
sys.path += ['../src']

import utils
import wavenet_models
from utils import save_audio
from wavenet import WaveNet
from wavenet_generator import WavenetGenerator
from nv_wavenet_generator import NVWavenetGenerator
from nv_wavenet_generator import Impl

In [14]:
# Inference params
# Base checkpoint path: per-decoder files are looked up as "<name>_<id>.pth"
# in its parent directory, which is also where 'args.pth' is read from.
checkpoint = Path('../checkpoints/pretrained_musicnet/bestmodel')
# Ids of the decoders to load.  NOTE: the model-loading cell rebinds
# `decoders` to the list of generator objects, so re-run this cell before
# re-running that one.
decoders = [0, 1, 2, 3, 4, 5]
# Batch size handed to the generator and used to split the inputs/latents.
batch_size = 1
# Sample rate passed to the generator (wav_freq) and to the audio player.
rate = 16000
# Length of each chunk (last axis of the encoder output) fed to decoder.generate().
split_size = 20
# Input audio files to run through the encoder.
file_paths = [Path('test.wav')]

In [15]:
def disp(x, decoder_ix):
    """Print a clip's value range and embed an audio player for it.

    Args:
        x: Tensor holding one generated clip.  It is moved to the CPU and
            decoded with ``utils.inv_mu_law`` before playback (so it is
            expected to hold mu-law encoded samples).
        decoder_ix: Identifier of the decoder, shown in the printed header.
    """
    # Decode first, exactly as before, so a decoding failure prints nothing.
    decoded = utils.inv_mu_law(x.cpu().numpy())

    lowest, highest = x.min(), x.max()
    print(f'Decoder: {decoder_ix}')
    print(f'X min: {lowest}, max: {highest}')

    # Playback uses the module-level `rate`.
    player = Audio(decoded.squeeze(), rate=rate)
    display(player)
        
def extract_id(path):
    """Return the integer decoder id encoded at the end of a checkpoint filename.

    The id is the text after the last underscore of the file stem, e.g.
    ``bestmodel_3.pth`` -> ``3``.

    Args:
        path: Checkpoint location as a ``str`` or ``pathlib.Path``.

    Returns:
        int: The decoder id.

    Raises:
        ValueError: If the text after the last underscore is not an integer.
    """
    # Work on the stem instead of slicing off four characters: this handles
    # extensions of any length and ignores underscores in parent directories.
    stem = Path(path).stem
    return int(stem.rsplit('_', 1)[-1])

In [16]:
print('Starting')
matplotlib.use('agg')

# Collect the per-decoder checkpoint files ("<name>_<id>.pth") that sit next
# to the base checkpoint path, keeping only the requested decoder ids.
# Directory listing order is arbitrary, so sort by id: this pins which file
# supplies the encoder weights and the order in which decoders are built.
# NOTE: `decoders` must still be the list of requested ids here; it is rebound
# below, so re-run the params cell before re-running this one.
checkpoints = checkpoint.parent.glob(checkpoint.name + '_*.pth')
checkpoints = sorted(
    (c for c in checkpoints if extract_id(c) in decoders),
    key=extract_id,
)
assert len(checkpoints) >= 1, "No checkpoints found."
print("Checkpoints found:", len(checkpoints))

# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model_args = torch.load(checkpoint.parent / 'args.pth')[0]
encoder = wavenet_models.Encoder(model_args)
encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
encoder.eval()
encoder = encoder.cuda()

# From here on `decoders` holds the generator objects, and `decoder_ids`
# holds the matching ids at the same positions.
decoders = []
decoder_ids = []
# Use a dedicated loop name so the configured `checkpoint` base path is not
# overwritten (which would break the glob above when the cell is re-run).
for checkpoint_path in checkpoints:
    decoder = WaveNet(model_args)
    decoder.load_state_dict(torch.load(checkpoint_path)['decoder_state'])
    decoder.eval()
    decoder = decoder.cuda()
    decoder = WavenetGenerator(decoder, batch_size, wav_freq=rate)
    # Alternative generator, kept for reference:
    # decoder = NVWavenetGenerator(decoder, rate * (split_size // 20), batch_size, Impl.AUTO)

    decoders.append(decoder)
    decoder_ids.append(extract_id(checkpoint_path))

Starting
Checkpoints found: 6


In [17]:
# Load every input file at the configured sample rate, encode it with
# utils.mu_law and stack the clips into a single (file, 1, sample) batch on
# the GPU.
waveforms = []

for file_path in file_paths:
    # Resample to the configured `rate` instead of a hard-coded 16000, and
    # bind the returned rate to its own name so the config value is not
    # overwritten.
    data, loaded_rate = librosa.load(file_path, sr=rate)
    assert loaded_rate == rate
    data = utils.mu_law(data)
    waveforms.append(torch.tensor(data).unsqueeze(0).float().cuda())

# torch.stack requires every clip to have the same number of samples.
xs = torch.stack(waveforms).contiguous()
print(f'xs size: {xs.size()}')

xs size: torch.Size([1, 1, 80000])


In [None]:
# Number of decoders to run, taken in `decoder_ids` order.  Generation is
# slow, so only the first decoder runs by default; set to None to run every
# loaded decoder.
max_decoders = 1

# Maps decoder id -> generated audio tensor (one row per input clip).
yy = {}
with torch.no_grad():
    # Encode all inputs once; the same latents are fed to every decoder.
    zz = torch.cat(
        [encoder(xs_batch) for xs_batch in torch.split(xs, batch_size)],
        dim=0,
    )

    with utils.timeit("Generation timer"):
        selected = list(zip(decoder_ids, decoders))[:max_decoders]
        for decoder_id, decoder in selected:
            generated_batches = []
            for zz_batch in torch.split(zz, batch_size):
                print(zz_batch.shape)
                # Feed the latents to the generator in chunks of `split_size`
                # along the last axis, then join the generated audio.
                splits = torch.split(zz_batch, split_size, -1)
                decoder.reset()
                audio_data = [
                    decoder.generate(cond).cpu()
                    for cond in tqdm.notebook.tqdm(splits)
                ]
                generated_batches.append(torch.cat(audio_data, -1))
            yy[decoder_id] = torch.cat(generated_batches, dim=0)

torch.Size([1, 64, 100])


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

  probabilities = F.softmax(prediction)
Generating: 100%|██████████| 20/20 [32:46<00:00, 98.33s/it] 
Generating:  55%|█████▌    | 11/20 [18:39<16:38, 110.90s/it]

In [None]:
# Play back every generated clip, grouped by the decoder that produced it.
for decoder_ix in yy:
    # zip() pairs each clip with an input path, capping playback at one clip
    # per input file.
    for sample_result, filepath in zip(yy[decoder_ix], file_paths):
        disp(sample_result, decoder_ix)

In [None]:
# Embed a player for the first input file (file_paths[0]).
Audio(file_paths[0])