In [1]:
from pathlib import Path
import librosa
import torch
from argparse import ArgumentParser
import matplotlib
import h5py
import tqdm
from tqdm.notebook import tqdm as notebook_tqdm
from IPython.display import Audio, display

import sys
sys.path += ['../music-translation/src']

import utils
import wavenet_models
from utils import save_audio
from wavenet import WaveNet
from wavenet_generator import WavenetGenerator
from nv_wavenet_generator import NVWavenetGenerator
from nv_wavenet_generator import Impl

In [2]:
# --- Inputs -----------------------------------------------------------------
# Checkpoint prefix: the loading cell globs for '<name>_*.pth' next to it and
# reads 'args.pth' from the same directory.
checkpoint = Path('../music-translation/checkpoints/pretrained_musicnet/bestmodel')
# Encoded tensors to decode; they are concatenated along dim 0 before generation.
file_paths = [Path('encoded-musicnet/encoded/Bach_Solo_Cello/2217.pt')]
# Ids of the decoders to load, matched against the numeric checkpoint suffix.
decoders = list(range(6))

# --- Generation settings ----------------------------------------------------
rate = 16000      # passed to the generator as wav_freq and used as playback rate
batch_size = 1    # generator batch size and chunk size along dim 0 of the input
split_size = 20   # chunk size along the last dim of the input fed per generate() call

In [3]:
def disp(x, decoder_ix):
    """Print a short summary of ``x`` and embed an audio player for it.

    Args:
        x: tensor of generated samples; it is moved to the CPU, converted to a
            NumPy array and passed through ``utils.inv_mu_law`` before playback
            (presumably mu-law encoded audio -- confirm against ``utils``).
        decoder_ix: label printed alongside the stats to identify the decoder.

    The player uses the notebook-level ``rate`` as its sample rate.
    """
    samples = x.cpu().numpy()
    decoded = utils.inv_mu_law(samples)

    print(f'Decoder: {decoder_ix}')
    # Stats are taken on the raw input tensor, not on the decoded waveform.
    lowest, highest = x.min(), x.max()
    print(f'X min: {lowest}, max: {highest}')

    player = Audio(decoded.squeeze(), rate=rate)
    display(player)
        
def extract_id(path):
    """Return the integer id found at the end of a checkpoint file name.

    The last four characters of ``str(path)`` are dropped (the length of a
    ``.pth`` suffix) and the text after the final underscore is parsed as an
    int, e.g. ``bestmodel_3.pth`` -> ``3``.

    Raises:
        ValueError: if the remaining text is not a valid integer.
    """
    without_suffix = str(path)[:-4]
    # rpartition leaves the whole string in the last slot when there is no '_'.
    _, _, id_text = without_suffix.rpartition('_')
    return int(id_text)

In [4]:
print('Starting')
matplotlib.use('agg')

# `decoders` arrives from the config cell as a list of requested ids, but this
# cell rebinds it to the generator objects further down. Keep the ids under
# their own name so that re-running this cell still filters correctly instead
# of comparing ids against generator objects.
if all(isinstance(d, int) for d in decoders):
    requested_decoder_ids = list(decoders)

# Per-decoder weights sit next to the checkpoint prefix as '<name>_<id>.pth'.
# Sorted by id because glob order depends on the filesystem.
checkpoints = sorted(
    (
        c for c in checkpoint.parent.glob(checkpoint.name + '_*.pth')
        if extract_id(c) in requested_decoder_ids
    ),
    key=extract_id,
)
assert len(checkpoints) >= 1, "No checkpoints found."

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# The first element of the stored object is used as the model arguments.
model_args = torch.load(checkpoint.parent / 'args.pth')[0]

# Outputs consumed by the generation cell: parallel lists of generators and ids.
decoders = []
decoder_ids = []
# The loop variable is deliberately NOT called `checkpoint`: that name is the
# configured prefix, and clobbering it makes the glob above find nothing on
# the next run of this cell.
for decoder_checkpoint in checkpoints:
    decoder = WaveNet(model_args)
    decoder.load_state_dict(torch.load(decoder_checkpoint)['decoder_state'])
    decoder.eval()
    decoder = decoder.cuda()
    decoder = WavenetGenerator(decoder, batch_size, wav_freq=rate)

    decoders.append(decoder)
    decoder_ids.append(extract_id(decoder_checkpoint))

Starting


In [7]:
# Generated audio per decoder id: yy[decoder_id] is a tensor made by joining
# the generated chunks along the last dim and the batches along dim 0.
yy = {}
with torch.no_grad():
    # Load every encoded input and stack them along dim 0.
    zz = torch.cat([torch.load(file_path) for file_path in file_paths], dim=0)
    print(zz.shape)

    with utils.timeit("Generation timer"):
        # decoder_ids and decoders are parallel lists built by the loading cell.
        for decoder_id, decoder in zip(decoder_ids, decoders):
            decoded_batches = []
            for zz_batch in torch.split(zz, batch_size):
                print(zz_batch.shape)
                # Feed the conditioning in chunks of `split_size` along the
                # last dim; the generator is reset once per batch, not per chunk.
                splits = torch.split(zz_batch, split_size, -1)
                decoder.reset()
                # notebook_tqdm replaces the deprecated tqdm.tqdm_notebook.
                audio_chunks = [
                    decoder.generate(cond).cpu() for cond in notebook_tqdm(splits)
                ]
                decoded_batches.append(torch.cat(audio_chunks, -1))
            yy[decoder_id] = torch.cat(decoded_batches, dim=0)

torch.Size([1, 64, 200])
torch.Size([1, 64, 200])


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

  probabilities = F.softmax(prediction)
Generating: 100%|██████████| 20/20 [08:49<00:00, 26.48s/it]
Generating: 100%|██████████| 20/20 [08:45<00:00, 26.29s/it]
Generating: 100%|██████████| 20/20 [09:22<00:00, 28.12s/it]
Generating: 100%|██████████| 20/20 [11:33<00:00, 34.66s/it]
Generating: 100%|██████████| 20/20 [15:44<00:00, 47.23s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.33s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.31s/it]
Generating: 100%|██████████| 20/20 [15:47<00:00, 47.36s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.29s/it]
Generating: 100%|██████████| 20/20 [15:44<00:00, 47.22s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [15:46<00:00, 47.35s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.35s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.34s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.32s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.32s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.31s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.31s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.33s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.33s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.28s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [15:44<00:00, 47.22s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.26s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.27s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.32s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.26s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.31s/it]
Generating: 100%|██████████| 20/20 [15:46<00:00, 47.32s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.29s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.26s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.27s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [15:46<00:00, 47.33s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.27s/it]
Generating: 100%|██████████| 20/20 [15:45<00:00, 47.28s/it]
Generating: 100%|██████████| 20/20 [14:50<00:00, 44.55s/it]
Generating: 100%|██████████| 20/20 [10:34<00:00, 31.74s/it]
Generating: 100%|██████████| 20/20 [10:39<00:00, 31.97s/it]
Generating: 100%|██████████| 20/20 [10:19<00:00, 30.99s/it]
Generating: 100%|██████████| 20/20 [09:29<00:00, 28.45s/it]
Generating: 100%|██████████| 20/20 [10:19<00:00, 31.00s/it]
Generating: 100%|██████████| 20/20 [10:25<00:00, 31.27s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [10:25<00:00, 31.26s/it]
Generating: 100%|██████████| 20/20 [10:25<00:00, 31.29s/it]
Generating: 100%|██████████| 20/20 [10:25<00:00, 31.28s/it]
Generating: 100%|██████████| 20/20 [10:24<00:00, 31.23s/it]
Generating: 100%|██████████| 20/20 [10:25<00:00, 31.27s/it]
Generating: 100%|██████████| 20/20 [10:25<00:00, 31.29s/it]
Generating: 100%|██████████| 20/20 [10:25<00:00, 31.29s/it]
Generating: 100%|██████████| 20/20 [10:25<00:00, 31.29s/it]
Generating: 100%|██████████| 20/20 [10:26<00:00, 31.32s/it]
Generating: 100%|██████████| 20/20 [10:26<00:00, 31.31s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [10:24<00:00, 31.24s/it]
Generating: 100%|██████████| 20/20 [10:26<00:00, 31.31s/it]
Generating: 100%|██████████| 20/20 [10:26<00:00, 31.31s/it]
Generating: 100%|██████████| 20/20 [10:26<00:00, 31.31s/it]
Generating: 100%|██████████| 20/20 [10:29<00:00, 31.46s/it]
Generating: 100%|██████████| 20/20 [10:26<00:00, 31.30s/it]
Generating: 100%|██████████| 20/20 [10:58<00:00, 32.93s/it]
Generating: 100%|██████████| 20/20 [09:47<00:00, 29.36s/it]
Generating: 100%|██████████| 20/20 [10:11<00:00, 30.56s/it]
Generating: 100%|██████████| 20/20 [10:25<00:00, 31.30s/it]


Generation timer took 46853280.14039993 ms



