In [46]:
import sys
from argparse import ArgumentParser
from pathlib import Path

import h5py
import librosa
import matplotlib
import numpy as np
import torch
import tqdm
import tqdm.notebook
from IPython.display import Audio, display

# Project sources live outside the notebook directory; extend the path before
# importing them.
sys.path += ['../music-translation/src']

import utils
import wavenet_models
from utils import save_audio
from wavenet import WaveNet
from wavenet_generator import WavenetGenerator
from nv_wavenet_generator import NVWavenetGenerator
from nv_wavenet_generator import Impl

In [47]:
# Load every tensor saved under encoded-musicnet/encoded/<subdir>/ and
# concatenate them along dim 0.
encoded_chunks = []
# iterdir() yields entries in arbitrary, filesystem-dependent order; sort so
# the row order of `encoded` is the same on every run.
for directory in sorted(Path('encoded-musicnet/encoded').iterdir()):
    if not directory.is_dir():
        continue  # ignore stray files sitting next to the sub-directories
    for path in sorted(directory.iterdir()):
        if not path.is_file():
            continue  # ignore nested directories (e.g. editor checkpoint folders)
        encoded_chunks += [torch.load(path)]
encoded = torch.cat(encoded_chunks, dim=0)
print(encoded)

tensor([[[ 8.7675e-02, -5.8459e-02, -7.1272e-02,  ..., -5.4422e-02,
          -1.3931e-01, -1.1900e-01],
         [-7.6073e-01, -7.4693e-01, -2.5110e-01,  ...,  7.0855e-01,
           7.1958e-01,  6.2726e-01],
         [ 7.0419e-02, -7.1841e-03, -8.3471e-02,  ..., -2.2622e-01,
          -3.5729e-01, -2.1907e-01],
         ...,
         [-2.4616e-01, -1.5479e-01,  1.1283e-01,  ..., -4.9195e-01,
          -5.8850e-01, -6.0281e-01],
         [-1.2716e-02,  1.3529e-01,  2.4566e-01,  ..., -5.3472e-01,
          -6.0958e-01, -6.3094e-01],
         [-1.3738e-01,  2.4088e-02,  1.2584e-01,  ...,  1.0179e-01,
           1.7099e-01,  1.2984e-01]],

        [[ 1.5184e-01,  1.3664e-01,  1.7382e-01,  ..., -4.5771e-02,
          -1.2188e-01, -1.2494e-01],
         [-6.4059e-01, -6.0842e-01, -6.0881e-01,  ...,  1.9787e-01,
           9.2620e-02, -1.1554e-01],
         [-7.6572e-03,  4.8237e-02,  4.3318e-02,  ..., -2.3059e-01,
          -1.6698e-01, -9.5024e-02],
         ...,
         [-3.1849e-01, -4

In [48]:
# Summary statistics over every element of `encoded`.
max_val = encoded.max().item()
min_val = encoded.min().item()

mean = encoded.mean().item()
# Population variance: mean of squared deviations from the global mean.
diffs = encoded - mean
var = diffs.pow(2.0).mean().item()

for label, value in (("max", max_val), ("min", min_val), ("mean", mean), ("var", var)):
    print(label, value)

max 2.077653408050537
min -1.7897802591323853
mean -0.008956626988947392
var 0.06574355810880661


In [49]:
# Seeded, local random generator so the sampled noise is identical after a
# kernel restart (the original draw used the unseeded global NumPy state).
NOISE_SEED = 0
noise_rng = np.random.default_rng(NOISE_SEED)

# Draw 5 Gaussian arrays of shape (1, 64, 200) using the `mean` and `var`
# computed earlier in the notebook; the standard deviation is sqrt(var).
noise_std = var ** 0.5
noise_vectors = [
    torch.from_numpy(noise_rng.normal(mean, noise_std, (1, 64, 200)))
    for _ in range(5)
]

In [50]:
# Sample rate passed to the generator (wav_freq) and to the audio player.
rate = 16_000
# Batch size handed to each WavenetGenerator and used to split the latents.
batch_size = 1
# Chunk length used when splitting a latent along its last axis.
split_size = 20
# Decoder ids whose checkpoints should be loaded.
decoders = list(range(6))
# Checkpoint path prefix; per-decoder files are matched as '<name>_*.pth'.
checkpoint = Path('../music-translation/checkpoints/pretrained_musicnet/bestmodel')



def disp(x, decoder_ix):
    """Print the value range of ``x`` and show it as an inline audio player.

    Args:
        x: tensor-like object; it is moved to CPU, converted to NumPy and
            passed through ``utils.inv_mu_law`` before playback.
        decoder_ix: identifier printed above the player.
    """
    samples = x.cpu().numpy()
    wav = utils.inv_mu_law(samples)

    print(f'Decoder: {decoder_ix}')
    print(f'X min: {x.min()}, max: {x.max()}')

    # Playback uses the module-level sample rate `rate`.
    player = Audio(wav.squeeze(), rate=rate)
    display(player)
        
def extract_id(path):
    """Return the integer decoder id at the end of a checkpoint file name.

    The id is the text after the last underscore of the file name once its
    extension is removed, e.g. ``bestmodel_3.pth`` -> ``3``.

    Args:
        path: ``str`` or ``pathlib.Path`` pointing at the checkpoint file.

    Returns:
        The decoder id as an ``int``.

    Raises:
        ValueError: if the text after the last underscore is not an integer.
    """
    # Use the stem instead of slicing off a fixed four characters, so the
    # result does not depend on the extension being exactly '.pth'.
    stem = Path(path).stem
    return int(stem.rsplit('_', 1)[-1])



print('Starting')
matplotlib.use('agg')

# Find the per-decoder checkpoint files next to the `checkpoint` prefix, keep
# only the requested decoder ids, and order them by id (glob order is
# arbitrary, which would otherwise make the load/generation order vary).
checkpoints = sorted(
    (c for c in checkpoint.parent.glob(checkpoint.name + '_*.pth')
     if extract_id(c) in decoders),
    key=extract_id,
)
assert len(checkpoints) >= 1, "No checkpoints found."

# Model hyper-parameters are stored alongside the checkpoints.
model_args = torch.load(checkpoint.parent / 'args.pth')[0]

# NOTE: from here on `decoders` holds generator objects, not the id list used
# for filtering above; `decoder_ids` keeps the matching ids in the same order.
decoders = []
decoder_ids = []
# A dedicated loop variable keeps the `checkpoint` prefix intact.
for checkpoint_path in checkpoints:
    decoder = WaveNet(model_args)
    checkpoint_state = torch.load(checkpoint_path)
    decoder.load_state_dict(checkpoint_state['decoder_state'])
    decoder.eval()
    decoder = decoder.cuda()
    decoder = WavenetGenerator(decoder, batch_size, wav_freq=rate)

    decoders += [decoder]
    decoder_ids += [extract_id(checkpoint_path)]

Starting


In [None]:
# Generated audio per decoder id: yy[decoder_id] is the concatenation (dim 0)
# of the audio produced for every latent batch.
yy = {}
with torch.no_grad():
    # Stack the noise latents into one float tensor on the GPU.
    zz = torch.cat(noise_vectors, dim=0).float().cuda()
    print(zz.shape)

    with utils.timeit("Generation timer"):
        for decoder_id, decoder in zip(decoder_ids, decoders):
            # Filled incrementally so partial results stay inspectable if the
            # (slow) generation is interrupted.
            yy[decoder_id] = []
            for zz_batch in torch.split(zz, batch_size):
                print(zz_batch.shape)
                # Feed the latent to the generator in chunks along the last axis.
                splits = torch.split(zz_batch, split_size, -1)
                decoder.reset()
                # tqdm.notebook.tqdm replaces the deprecated tqdm.tqdm_notebook.
                audio_chunks = [
                    decoder.generate(cond).cpu()
                    for cond in tqdm.notebook.tqdm(splits)
                ]
                audio_data = torch.cat(audio_chunks, -1)
                yy[decoder_id] += [audio_data]
            yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)

torch.Size([5, 64, 200])
torch.Size([1, 64, 200])


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

  probabilities = F.softmax(prediction)
Generating:  50%|█████     | 10/20 [07:35<07:45, 46.57s/it]