In [51]:
# stdlib
import itertools
import sys
from pathlib import Path

# third-party
import h5py
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from tqdm.auto import tqdm as tqdm_auto
from IPython.display import Audio, display
from scipy import linalg
from sklearn import mixture

%matplotlib inline

# project-local modules from the music-translation checkout (must follow the sys.path tweak)
sys.path += ['../music-translation/src']

import utils
import wavenet_models
from utils import save_audio
from wavenet import WaveNet
from wavenet_generator import WavenetGenerator
from nv_wavenet_generator import NVWavenetGenerator
from nv_wavenet_generator import Impl

In [2]:
# Load every saved encoding under encoded-musicnet/encoded/<subdir>/<file> and stack
# them along dim 0 into one tensor, then flatten each row into a feature vector for
# the GMM.
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order. Sorting
# makes the row order of `flattened` (and every model fit on it) reproducible.
ENCODED_DIR = Path('encoded-musicnet/encoded')

encoded_parts = [
    # map_location='cpu' so loading does not need the device the tensors were saved from
    torch.load(path, map_location='cpu')
    for directory in sorted(p for p in ENCODED_DIR.iterdir() if p.is_dir())
    for path in sorted(directory.iterdir())
]
encoded = torch.cat(encoded_parts, dim=0)
# Collapse all non-batch dims into a single feature axis: one row per encoded clip.
flattened = torch.flatten(encoded, 1).cpu().numpy()
print(flattened.shape)

(203, 12800)


In [4]:
# Model selection: fit a Gaussian mixture for every (covariance type, n_components)
# pair and keep the one with the lowest BIC on the flattened encodings.
# NOTE(review): with 12800-dim features (see the shape printed above), 'tied'/'full'
# covariances are presumably 12800x12800 per fit -- very memory/time hungry for 203 rows.
GMM_SEED = 0
# Pass a RandomState *instance* rather than an int. The whole search is reproducible,
# and successive best_gmm.sample() calls still advance the generator. With an int
# seed, sklearn rebuilds a fresh generator on each call and would repeat the same draw.
gmm_rng = np.random.RandomState(GMM_SEED)

lowest_bic = np.inf  # np.infty alias was removed in NumPy 2.0
best_gmm = None      # reset so a stale model from an earlier run can't leak through
bic = []
n_components_range = range(1, 10)
cv_types = ['spherical', 'tied', 'diag', 'full']
for cv_type in cv_types:
    for n_components in n_components_range:
        # Fit a Gaussian mixture with EM
        gmm = mixture.GaussianMixture(n_components=n_components,
                                      covariance_type=cv_type, reg_covar=1e-4,
                                      random_state=gmm_rng)
        gmm.fit(flattened)
        bic.append(gmm.bic(flattened))
        if bic[-1] < lowest_bic:
            lowest_bic = bic[-1]
            best_gmm = gmm
assert best_gmm is not None, "No GMM produced a BIC below +inf."

In [52]:
# Draw new latent codes from the selected mixture and reshape each flat draw back to
# the per-clip encoding shape (taken from the loaded data instead of hard-coding it),
# keeping a leading batch dim of 1 so each sample is a (1, *latent_shape) tensor.
N_SAMPLES = 2
latent_shape = tuple(encoded.shape[1:])
draws, _ = best_gmm.sample(N_SAMPLES)  # (N_SAMPLES, n_features); labels unused
samples = [torch.from_numpy(draw.reshape(1, *latent_shape)) for draw in draws]

In [53]:
# Decoding configuration.
rate = 16000       # sample rate (Hz): WavenetGenerator wav_freq and Audio playback rate
batch_size = 1     # chunk size for torch.split over the latent batch; also given to WavenetGenerator
split_size = 20    # last-axis chunk length fed to each decoder.generate() call
decoders = list(range(6))  # decoder ids whose checkpoints should be loaded
checkpoint = Path('../music-translation/checkpoints/pretrained_musicnet') / 'bestmodel'



def disp(x, decoder_ix):
    """Print summary stats for a generated signal and show it as an inline audio player.

    x: tensor (any device) holding generated samples; it is moved to CPU, converted to
        NumPy and passed through ``utils.inv_mu_law`` (presumably mu-law decoding --
        confirm) before playback.
    decoder_ix: label for the decoder that produced ``x``; only printed.

    Playback uses the notebook-level ``rate`` as the sample rate.
    """
    x_np = x.cpu().numpy()
    waveform = utils.inv_mu_law(x_np)
    print(f'Decoder: {decoder_ix}')
    print(f'X min: {x.min()}, max: {x.max()}')
    player = Audio(waveform.squeeze(), rate=rate)
    display(player)
        
def extract_id(path):
    """Return the integer decoder id from a checkpoint filename.

    The id is the last ``_``-separated token of the file stem, e.g.
    ``.../bestmodel_3.pth`` -> ``3``. Accepts a ``str`` or ``Path``. Using the stem,
    rather than slicing off a fixed four characters, keeps this correct for suffixes
    other than ``.pth``. Underscores in parent directories are ignored.

    Raises:
        ValueError: if the last token of the stem is not an integer.
    """
    return int(Path(path).stem.rsplit('_', 1)[-1])



print('Starting')
# No matplotlib.use('agg') here: in a notebook it would replace the inline backend
# selected by %matplotlib inline, and this cell does not plot anything.

# Keep only checkpoints for the requested decoder ids. Sort them by id because glob
# order is filesystem-dependent and would otherwise make decoder order vary.
wanted_ids = set(decoders)
checkpoints = sorted(
    (c for c in checkpoint.parent.glob(checkpoint.name + '_*.pth')
     if extract_id(c) in wanted_ids),
    key=extract_id,
)
assert len(checkpoints) >= 1, "No checkpoints found."

# Element 0 of args.pth is the argument object WaveNet is constructed from.
# NOTE(review): presumably a pickled argparse.Namespace -- on torch>=2.6 torch.load
# defaults to weights_only=True and may need weights_only=False here; confirm.
model_args = torch.load(checkpoint.parent / 'args.pth')[0]

# From here on `decoders` holds one WavenetGenerator per checkpoint, aligned
# index-for-index with `decoder_ids`; it replaces the list of requested ids.
decoders = []
decoder_ids = []
for ckpt_path in checkpoints:  # distinct name: don't clobber the base `checkpoint` path
    net = WaveNet(model_args)
    # Load onto CPU first; weights are copied into the module, which is then moved to GPU.
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu')['decoder_state'])
    net.eval()
    net = net.cuda()
    decoders.append(WavenetGenerator(net, batch_size, wav_freq=rate))
    decoder_ids.append(extract_id(ckpt_path))

Starting


In [None]:
# Decode the GMM-sampled latent vectors with every loaded decoder.
# Result: yy[decoder_id] is a tensor of generated audio, one row per sampled vector.
yy = {}
with torch.no_grad():
    # Stack the sampled latents into a single float batch on the GPU.
    zz = torch.cat(samples, dim=0).float().cuda()
    print(zz.shape)

    with utils.timeit("Generation timer"):
        for decoder_id, decoder in zip(decoder_ids, decoders):
            clips = []
            for zz_batch in torch.split(zz, batch_size):
                print(zz_batch.shape)
                # Feed the conditioning to the generator in `split_size`-long chunks
                # along the last axis. reset() is called once per batch
                # (presumably clears generator state -- confirm).
                splits = torch.split(zz_batch, split_size, -1)
                decoder.reset()
                # tqdm.auto replaces the deprecated tqdm.tqdm_notebook
                audio_data = [decoder.generate(cond).cpu() for cond in tqdm_auto(splits)]
                clips.append(torch.cat(audio_data, -1))
            yy[decoder_id] = torch.cat(clips, dim=0)

torch.Size([2, 64, 200])
torch.Size([1, 64, 200])


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

  probabilities = F.softmax(prediction)
Generating: 100%|██████████| 20/20 [10:30<00:00, 31.55s/it]
Generating: 100%|██████████| 20/20 [10:24<00:00, 31.23s/it]
Generating: 100%|██████████| 20/20 [10:20<00:00, 31.05s/it]
Generating: 100%|██████████| 20/20 [10:26<00:00, 31.32s/it]
Generating: 100%|██████████| 20/20 [10:18<00:00, 30.93s/it]
Generating: 100%|██████████| 20/20 [10:18<00:00, 30.90s/it]
Generating: 100%|██████████| 20/20 [10:25<00:00, 31.28s/it]
Generating:   0%|          | 0/20 [00:00<?, ?it/s]