In [15]:
from pathlib import Path
import librosa
import torch
from argparse import ArgumentParser
import matplotlib
import h5py
import tqdm
from IPython.display import Audio, display

import numpy as np

import sys
sys.path += ['../music-translation/src']

import utils
import wavenet_models
from utils import save_audio
from wavenet import WaveNet
from wavenet_generator import WavenetGenerator
from nv_wavenet_generator import NVWavenetGenerator
from nv_wavenet_generator import Impl
from scipy.io import wavfile

In [2]:
encoded = []
for directory in Path('encoded-musicnet/encoded').iterdir():
    for path in directory.iterdir():
        encoded += [torch.load(path)]
encoded = torch.cat(encoded, dim=0)

In [3]:
max_val = torch.max(encoded).item()
min_val = torch.min(encoded).item()

mean = torch.mean(encoded).item()
diffs = encoded - mean
var = torch.mean(torch.pow(diffs, 2.0)).item()
print("max", max_val)
print("min", min_val)
print("mean", mean)
print("var", var)

max 2.077653408050537
min -1.7897802591323853
mean -0.008956626988947392
var 0.06574355810880661


In [4]:
noise_vectors = []
for i in range(3):
    noise = np.random.normal(mean, var**0.5, (1, 64, 200))
    noise = torch.from_numpy(noise)
    noise_vectors += [noise]

In [5]:
checkpoint = Path('../music-translation/checkpoints/pretrained_musicnet/bestmodel')
decoders = [0, 1, 2, 3, 4, 5]
batch_size = 1
rate = 16000
split_size = 20



def disp(x, decoder_ix):
    wav = utils.inv_mu_law(x.cpu().numpy())
    print(f'Decoder: {decoder_ix}')
    print(f'X min: {x.min()}, max: {x.max()}')

    display(Audio(wav.squeeze(), rate=rate))
        
def extract_id(path):
    decoder_id = str(path)[:-4].split('_')[-1]
    return int(decoder_id)



print('Starting')
matplotlib.use('agg')

checkpoints = checkpoint.parent.glob(checkpoint.name + '_*.pth')
checkpoints = [c for c in checkpoints if extract_id(c) in decoders]
assert len(checkpoints) >= 1, "No checkpoints found."

model_args = torch.load(checkpoint.parent / 'args.pth')[0]

decoders = []
decoder_ids = []
for checkpoint in checkpoints:
    decoder = WaveNet(model_args)
    decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
    decoder.eval()
    decoder = decoder.cuda()
    decoder = WavenetGenerator(decoder, batch_size, wav_freq=rate)
    
    decoders += [decoder]
    decoder_ids += [extract_id(checkpoint)]

Starting


In [6]:
yy = {}
with torch.no_grad():
    zz = []
    for vector in noise_vectors:
        zz += [vector]
    zz = torch.cat(zz, dim=0).float().cuda()
    print(zz.shape)

    with utils.timeit("Generation timer"):
        for i, decoder_id in enumerate(decoder_ids):
            yy[decoder_id] = []
            decoder = decoders[i]
            for zz_batch in torch.split(zz, batch_size):
                print(zz_batch.shape)
                splits = torch.split(zz_batch, split_size, -1)
                audio_data = []
                decoder.reset()
                for cond in tqdm.tqdm_notebook(splits):
                    audio_data += [decoder.generate(cond).cpu()]
                audio_data = torch.cat(audio_data, -1)
                yy[decoder_id] += [audio_data]
            yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)

torch.Size([3, 64, 200])
torch.Size([1, 64, 200])


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

  probabilities = F.softmax(prediction)
Generating: 100%|██████████| 20/20 [08:21<00:00, 25.05s/it]
Generating: 100%|██████████| 20/20 [08:27<00:00, 25.36s/it]
Generating: 100%|██████████| 20/20 [08:23<00:00, 25.18s/it]
Generating: 100%|██████████| 20/20 [08:21<00:00, 25.06s/it]
Generating: 100%|██████████| 20/20 [08:22<00:00, 25.13s/it]
Generating: 100%|██████████| 20/20 [08:21<00:00, 25.09s/it]
Generating: 100%|██████████| 20/20 [08:24<00:00, 25.21s/it]
Generating: 100%|██████████| 20/20 [08:20<00:00, 25.02s/it]
Generating: 100%|██████████| 20/20 [08:46<00:00, 26.30s/it]
Generating: 100%|██████████| 20/20 [08:51<00:00, 26.55s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:35<00:00, 25.78s/it]
Generating: 100%|██████████| 20/20 [08:43<00:00, 26.15s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.48s/it]
Generating: 100%|██████████| 20/20 [08:23<00:00, 25.16s/it]
Generating: 100%|██████████| 20/20 [08:21<00:00, 25.08s/it]
Generating: 100%|██████████| 20/20 [08:18<00:00, 24.92s/it]
Generating: 100%|██████████| 20/20 [08:18<00:00, 24.94s/it]
Generating: 100%|██████████| 20/20 [08:25<00:00, 25.28s/it]
Generating: 100%|██████████| 20/20 [08:21<00:00, 25.05s/it]
Generating: 100%|██████████| 20/20 [08:37<00:00, 25.90s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:20<00:00, 25.02s/it]
Generating: 100%|██████████| 20/20 [08:37<00:00, 25.86s/it]
Generating: 100%|██████████| 20/20 [08:28<00:00, 25.40s/it]
Generating: 100%|██████████| 20/20 [08:22<00:00, 25.13s/it]
Generating: 100%|██████████| 20/20 [08:18<00:00, 24.92s/it]
Generating: 100%|██████████| 20/20 [08:19<00:00, 24.96s/it]
Generating: 100%|██████████| 20/20 [08:43<00:00, 26.17s/it]
Generating: 100%|██████████| 20/20 [08:17<00:00, 24.87s/it]
Generating: 100%|██████████| 20/20 [08:25<00:00, 25.28s/it]
Generating: 100%|██████████| 20/20 [09:00<00:00, 27.01s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:34<00:00, 25.75s/it]
Generating: 100%|██████████| 20/20 [08:46<00:00, 26.32s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.53s/it]
Generating: 100%|██████████| 20/20 [08:33<00:00, 25.68s/it]
Generating: 100%|██████████| 20/20 [08:27<00:00, 25.39s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.49s/it]
Generating: 100%|██████████| 20/20 [08:36<00:00, 25.80s/it]
Generating: 100%|██████████| 20/20 [08:38<00:00, 25.91s/it]
Generating: 100%|██████████| 20/20 [08:44<00:00, 26.23s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.70s/it]



torch.Size([1, 64, 200])


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:37<00:00, 25.88s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.74s/it]
Generating: 100%|██████████| 20/20 [08:55<00:00, 26.76s/it]
Generating: 100%|██████████| 20/20 [08:48<00:00, 26.42s/it]
Generating: 100%|██████████| 20/20 [08:31<00:00, 25.58s/it]
Generating: 100%|██████████| 20/20 [08:39<00:00, 25.98s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.54s/it]
Generating: 100%|██████████| 20/20 [09:25<00:00, 28.29s/it]
Generating: 100%|██████████| 20/20 [08:43<00:00, 26.17s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.50s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:28<00:00, 25.44s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.47s/it]
Generating: 100%|██████████| 20/20 [08:33<00:00, 25.67s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.45s/it]
Generating: 100%|██████████| 20/20 [08:28<00:00, 25.42s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.47s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.51s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.48s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.71s/it]
Generating: 100%|██████████| 20/20 [08:37<00:00, 25.88s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:44<00:00, 26.25s/it]
Generating: 100%|██████████| 20/20 [08:37<00:00, 25.87s/it]
Generating: 100%|██████████| 20/20 [08:40<00:00, 26.05s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.72s/it]
Generating: 100%|██████████| 20/20 [09:05<00:00, 27.29s/it]
Generating: 100%|██████████| 20/20 [08:41<00:00, 26.07s/it]
Generating: 100%|██████████| 20/20 [08:39<00:00, 25.98s/it]
Generating: 100%|██████████| 20/20 [08:36<00:00, 25.83s/it]
Generating: 100%|██████████| 20/20 [08:52<00:00, 26.62s/it]
Generating: 100%|██████████| 20/20 [09:29<00:00, 28.45s/it]



torch.Size([1, 64, 200])


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [09:18<00:00, 27.90s/it]
Generating: 100%|██████████| 20/20 [08:43<00:00, 26.18s/it]
Generating: 100%|██████████| 20/20 [08:52<00:00, 26.60s/it]
Generating: 100%|██████████| 20/20 [08:51<00:00, 26.60s/it]
Generating: 100%|██████████| 20/20 [08:45<00:00, 26.25s/it]
Generating: 100%|██████████| 20/20 [08:43<00:00, 26.19s/it]
Generating: 100%|██████████| 20/20 [08:42<00:00, 26.12s/it]
Generating: 100%|██████████| 20/20 [08:44<00:00, 26.23s/it]
Generating: 100%|██████████| 20/20 [08:43<00:00, 26.18s/it]
Generating: 100%|██████████| 20/20 [08:50<00:00, 26.53s/it]



torch.Size([1, 64, 200])


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [12:21<00:00, 37.06s/it]
Generating: 100%|██████████| 20/20 [08:25<00:00, 25.27s/it]
Generating: 100%|██████████| 20/20 [08:27<00:00, 25.39s/it]
Generating: 100%|██████████| 20/20 [08:27<00:00, 25.37s/it]
Generating: 100%|██████████| 20/20 [08:32<00:00, 25.63s/it]
Generating: 100%|██████████| 20/20 [08:32<00:00, 25.62s/it]
Generating: 100%|██████████| 20/20 [08:32<00:00, 25.60s/it]
Generating: 100%|██████████| 20/20 [08:33<00:00, 25.66s/it]
Generating: 100%|██████████| 20/20 [08:33<00:00, 25.68s/it]
Generating: 100%|██████████| 20/20 [08:32<00:00, 25.65s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:33<00:00, 25.67s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.71s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.71s/it]
Generating: 100%|██████████| 20/20 [08:33<00:00, 25.69s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.73s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.71s/it]
Generating: 100%|██████████| 20/20 [08:35<00:00, 25.77s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.72s/it]
Generating: 100%|██████████| 20/20 [08:45<00:00, 26.25s/it]
Generating: 100%|██████████| 20/20 [09:17<00:00, 27.89s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [09:21<00:00, 28.07s/it]
Generating: 100%|██████████| 20/20 [09:22<00:00, 28.13s/it]
Generating: 100%|██████████| 20/20 [08:52<00:00, 26.64s/it]
Generating: 100%|██████████| 20/20 [09:00<00:00, 27.02s/it]
Generating: 100%|██████████| 20/20 [14:10<00:00, 42.52s/it]
Generating: 100%|██████████| 20/20 [08:55<00:00, 26.79s/it]
Generating: 100%|██████████| 20/20 [09:24<00:00, 28.22s/it]
Generating: 100%|██████████| 20/20 [10:08<00:00, 30.42s/it]
Generating: 100%|██████████| 20/20 [10:01<00:00, 30.05s/it]
Generating: 100%|██████████| 20/20 [10:01<00:00, 30.08s/it]



torch.Size([1, 64, 200])


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [11:31<00:00, 34.58s/it]
Generating: 100%|██████████| 20/20 [10:37<00:00, 31.87s/it]
Generating: 100%|██████████| 20/20 [11:05<00:00, 33.29s/it]
Generating: 100%|██████████| 20/20 [09:28<00:00, 28.42s/it]
Generating: 100%|██████████| 20/20 [08:38<00:00, 25.92s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.46s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.54s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.50s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.52s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.51s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:32<00:00, 25.61s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.47s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.53s/it]
Generating: 100%|██████████| 20/20 [08:31<00:00, 25.58s/it]
Generating: 100%|██████████| 20/20 [08:31<00:00, 25.59s/it]
Generating: 100%|██████████| 20/20 [08:38<00:00, 25.94s/it]
Generating: 100%|██████████| 20/20 [10:01<00:00, 30.05s/it]
Generating: 100%|██████████| 20/20 [10:38<00:00, 31.92s/it]
Generating: 100%|██████████| 20/20 [10:44<00:00, 32.24s/it]
Generating: 100%|██████████| 20/20 [10:20<00:00, 31.05s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:49<00:00, 26.47s/it]
Generating: 100%|██████████| 20/20 [08:36<00:00, 25.85s/it]
Generating: 100%|██████████| 20/20 [08:41<00:00, 26.07s/it]
Generating: 100%|██████████| 20/20 [08:40<00:00, 26.04s/it]
Generating: 100%|██████████| 20/20 [08:39<00:00, 25.97s/it]
Generating: 100%|██████████| 20/20 [08:37<00:00, 25.87s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.51s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.50s/it]
Generating: 100%|██████████| 20/20 [08:28<00:00, 25.42s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.48s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:33<00:00, 25.67s/it]
Generating: 100%|██████████| 20/20 [08:35<00:00, 25.76s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.53s/it]
Generating: 100%|██████████| 20/20 [08:29<00:00, 25.48s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.53s/it]
Generating: 100%|██████████| 20/20 [08:31<00:00, 25.57s/it]
Generating: 100%|██████████| 20/20 [08:31<00:00, 25.57s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.53s/it]
Generating: 100%|██████████| 20/20 [08:30<00:00, 25.53s/it]
Generating: 100%|██████████| 20/20 [08:32<00:00, 25.62s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [08:37<00:00, 25.87s/it]
Generating: 100%|██████████| 20/20 [08:31<00:00, 25.56s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.71s/it]
Generating: 100%|██████████| 20/20 [08:33<00:00, 25.70s/it]
Generating: 100%|██████████| 20/20 [08:34<00:00, 25.72s/it]
Generating: 100%|██████████| 20/20 [08:33<00:00, 25.66s/it]
Generating: 100%|██████████| 20/20 [08:33<00:00, 25.66s/it]
Generating: 100%|██████████| 20/20 [08:49<00:00, 26.49s/it]
Generating: 100%|██████████| 20/20 [09:24<00:00, 28.24s/it]
Generating: 100%|██████████| 20/20 [09:23<00:00, 28.16s/it]


torch.Size([1, 64, 200])





HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [09:21<00:00, 28.08s/it]
Generating: 100%|██████████| 20/20 [09:21<00:00, 28.10s/it]
Generating: 100%|██████████| 20/20 [09:22<00:00, 28.14s/it]
Generating: 100%|██████████| 20/20 [09:23<00:00, 28.18s/it]
Generating: 100%|██████████| 20/20 [09:24<00:00, 28.22s/it]
Generating: 100%|██████████| 20/20 [09:23<00:00, 28.17s/it]
Generating: 100%|██████████| 20/20 [09:24<00:00, 28.23s/it]
Generating: 100%|██████████| 20/20 [09:23<00:00, 28.17s/it]
Generating: 100%|██████████| 20/20 [09:25<00:00, 28.29s/it]
Generating: 100%|██████████| 20/20 [09:23<00:00, 28.18s/it]



torch.Size([1, 64, 200])


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Generating: 100%|██████████| 20/20 [09:25<00:00, 28.28s/it]
Generating: 100%|██████████| 20/20 [09:25<00:00, 28.29s/it]
Generating: 100%|██████████| 20/20 [09:25<00:00, 28.26s/it]
Generating: 100%|██████████| 20/20 [09:25<00:00, 28.28s/it]
Generating: 100%|██████████| 20/20 [09:27<00:00, 28.38s/it]
Generating: 100%|██████████| 20/20 [09:28<00:00, 28.43s/it]
Generating: 100%|██████████| 20/20 [09:25<00:00, 28.28s/it]
Generating: 100%|██████████| 20/20 [09:28<00:00, 28.41s/it]
Generating: 100%|██████████| 20/20 [09:25<00:00, 28.25s/it]
Generating: 100%|██████████| 20/20 [09:29<00:00, 28.48s/it]


Generation timer took 95904656.8095684 ms





In [22]:
for decoder_ix, decoder_result in yy.items():
    i=1
    for sample_result, noise_vector in zip(decoder_result, noise_vectors):
        disp(sample_result, decoder_ix)
        wav = utils.inv_mu_law(sample_result.cpu().numpy())
        save_audio(wav.squeeze(), Path("results/noise_d" + str(decoder_ix) + "-" + str(i) + ".wav"), rate)
        i++

Decoder: 3
X min: 42, max: 211


Decoder: 3
X min: 44, max: 210


Decoder: 3
X min: 49, max: 207


Decoder: 2
X min: 54, max: 202


Decoder: 2
X min: 55, max: 202


Decoder: 2
X min: 53, max: 205


Decoder: 1
X min: 0, max: 247


Decoder: 1
X min: 2, max: 244


Decoder: 1
X min: 0, max: 249


Decoder: 0
X min: 39, max: 212


Decoder: 0
X min: 38, max: 213


Decoder: 0
X min: 37, max: 210


Decoder: 5
X min: 11, max: 254


Decoder: 5
X min: 9, max: 252


Decoder: 5
X min: 10, max: 254


Decoder: 4
X min: 10, max: 242


Decoder: 4
X min: 19, max: 235


Decoder: 4
X min: 17, max: 239
