In [None]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt

import pathlib
import pickle
import numpy as np
import librosa

import IPython
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils

from tqdm import tqdm

from vqvae import VQVAE, InferenceVQVAE
from scheduler import CycleScheduler

from nsynth_dataset import NSynthH5Dataset
from GANsynth_pytorch.pytorch_nsynth_lib.nsynth import (
    NSynth, get_mel_spectrogram_and_IF)

import GANsynth_pytorch

In [None]:
# NOTE(review): hard-coded second GPU; change if the machine has a single GPU.
device = 'cuda:1'
valid_pitch_range = [24, 84]
HOP_LENGTH = 512


# Which on-disk flavour of NSynth to read ('hdf5' or 'wav').
dataset_type = 'wav'
nsynth_dataset_path = pathlib.Path('~/code/data/nsynth/valid/json_wav').expanduser()
if dataset_type == 'hdf5':
    # The HDF5 dataset class imported at the top of the notebook is
    # `NSynthH5Dataset`; `NSynthDataset` is not defined anywhere in scope.
    # NOTE(review): confirm NSynthH5Dataset accepts `root_path` and
    # `use_mel_frequency_scale` keyword arguments.
    nsynth_dataset = NSynthH5Dataset(
        root_path=nsynth_dataset_path,
        use_mel_frequency_scale=True)
elif dataset_type == 'wav':
    def chained_transform(sample):
        """Turn one NSynth sample into a mel magnitude + mel IF image tensor.

        `get_mel_spectrogram_and_IF` produces the two components (using
        HOP_LENGTH); both are cast to float32 and combined by
        `NSynthH5Dataset._to_image`, giving the two input channels the model
        is built with (`in_channel = 2`).
        """
        mel_spec, mel_IF = get_mel_spectrogram_and_IF(
            sample, hop_length=HOP_LENGTH)
        return NSynthH5Dataset._to_image(
            [component.astype(np.float32) for component in [mel_spec, mel_IF]])

    # NOTE(review): not used below -- NSynth receives `chained_transform` directly.
    to_mel_spec_and_if = transforms.Lambda(chained_transform)
    nsynth_dataset = NSynth(
        root=str(nsynth_dataset_path),
        transform=chained_transform,
        valid_pitch_range=valid_pitch_range,
        categorical_field_list=[],
        convert_to_float=True)
else:
    raise ValueError(f"Unrecognized dataset type: {dataset_type!r}")

# Examples are loaded in the main process (no worker subprocesses), one per
# batch, in shuffled order.
# NOTE(review): `batch_size` is not passed to the DataLoader below, which uses
# batch_size=1, and no other visible cell reads it -- confirm intent.
num_workers = 0
batch_size = 64
loader = DataLoader(
    nsynth_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=num_workers,
)
# Model input: two channels (mel magnitude and mel IF).
in_channel = 2

# No output activation on the decoder.
vqvae_decoder_activation = None

# Input normalization: either load precomputed statistics from disk, or hand
# the model the loader (presumably so it computes them itself). When
# normalization is disabled, both stay None.
normalizer_statistics = None
dataloader_for_gansynth_normalization = None

disable_input_normalization = True
use_precomputed_normalization_statistics = True
if disable_input_normalization:
    pass  # neither statistics nor a loader are given to the model
elif use_precomputed_normalization_statistics:
    normalizer_statistics_path = nsynth_dataset_path / '../normalization_statistics.pkl'
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
    with open(normalizer_statistics_path, 'rb') as stats_file:
        normalizer_statistics = pickle.load(stats_file)
else:
    dataloader_for_gansynth_normalization = loader

print("Initializing model")
vqvae_path = pathlib.Path('../vq-vae-2-pytorch/checkpoints/086f45/vqvae_nsynth_020.pt')
vqvae = VQVAE(in_channel=in_channel,
              decoder_output_activation=vqvae_decoder_activation,
              normalizer_statistics=normalizer_statistics,
              dataloader_for_gansynth_normalization=dataloader_for_gansynth_normalization,
              groups=2
              )
# Map the checkpoint's tensors to CPU. load_state_dict copies them into the
# freshly constructed model's own parameters anyway, so this avoids needing
# (and allocating on) whatever GPU the checkpoint was saved from. The model is
# moved to `device` afterwards.
vqvae.load_state_dict(torch.load(vqvae_path, map_location='cpu'))
vqvae.to(device)
inference_vqvae = InferenceVQVAE(vqvae, device)

In [None]:
# NOTE(review): `iterator` is not referenced by any later visible cell. It is
# kept because creating it from a shuffling loader may consume torch RNG state,
# which affects which examples later cells draw.
iterator = iter(loader)
# All following cells use the model in evaluation mode.
vqvae.eval()

In [None]:
samples, reconstructions = inference_vqvae.sample_reconstructions(loader)


def _play_mag_and_IF(label, mag_and_IF):
    """Convert mel magnitude/IF to audio, print `label`, and play it at 16 kHz.

    The conversion is done by `inference_vqvae.mag_and_IF_to_audio`. The audio
    is returned so it stays available for later inspection.
    """
    audio = inference_vqvae.mag_and_IF_to_audio(mag_and_IF,
                                                use_mel_frequency=True)
    print(label)
    IPython.display.display(IPython.display.Audio(audio, rate=16000))
    return audio


sample_audio = _play_mag_and_IF("Original audio", samples)
reconstructions_audio = _play_mag_and_IF("Reconstructed audio", reconstructions)

In [None]:
# Display the learned codebooks, centered by subtracting the mean over dim 0.
def _show_centered_codebook(codebook, title):
    """matshow `codebook` minus its mean along dim 0, titled `title`."""
    centered = codebook - codebook.mean(dim=0, keepdim=True)
    plt.matshow(centered)
    plt.title(title, pad=20)
    plt.show()


codes_top = vqvae.quantize_t.embed.data.cpu()
_show_centered_codebook(codes_top, 'Top layer (1st, unconditioned layer) codes')

codes_bottom = vqvae.quantize_b.embed.data.cpu()
_show_centered_codebook(codes_bottom, 'Bottom layer (2nd, conditioned layer) codes')

In [None]:
# Code distribution over the dataset served by `loader`. Note that
# `nsynth_dataset_path` points at the NSynth *validation* split, not the
# training split. One counter per codebook entry, reset on every run of this
# cell so re-running it does not double-count.
num_embeddings_top = vqvae.quantize_t.n_embed
num_embeddings_bottom = vqvae.quantize_b.n_embed
code_usage_top = np.zeros(num_embeddings_top)
code_usage_bottom = np.zeros(num_embeddings_bottom)

def count_uses(codes_batch, code_usage_matrix):
    """Add, in place, the number of occurrences of each code index in `codes_batch`.

    Args:
        codes_batch: code indices, either as a torch tensor (on any device) or
            as anything `np.asarray` accepts. Every element is counted,
            whatever the shape.
        code_usage_matrix: 1-D numpy array of counters indexed by code,
            updated in place.

    Raises:
        IndexError: if an index is out of range for `code_usage_matrix`.
    """
    if isinstance(codes_batch, torch.Tensor):
        codes_batch = codes_batch.detach().cpu().numpy()
    # Truncate to integers, matching int(value) in the element-wise version.
    indices = np.asarray(codes_batch).astype(np.int64).ravel()
    # ufunc.at is unbuffered, so repeated indices are all counted (a plain
    # `code_usage_matrix[indices] += 1` would count duplicates once). This
    # replaces a per-element Python loop.
    np.add.at(code_usage_matrix, indices, 1)

# Encode every batch without tracking gradients, then tally the code indices
# produced for the top and bottom quantizers.
with torch.no_grad():
    for mel_batch, _ in tqdm(loader):
        _, _, _, top_ids, bottom_ids = vqvae.encode(mel_batch.to(device))
        for ids, usage in ((top_ids, code_usage_top),
                           (bottom_ids, code_usage_bottom)):
            count_uses(ids, usage)

In [None]:
# Side-by-side bar charts of how often each codebook entry was selected.
# Counts are shown as log(1 + count), so unused codes sit at 0.
fig, (ax_bottom, ax_top) = plt.subplots(1, 2, figsize=(12, 4))
for ax, title, usage in ((ax_bottom, "Bottom layer", code_usage_bottom),
                         (ax_top, "Top layer", code_usage_top)):
    ax.bar(np.arange(len(usage)), np.log1p(usage))
    ax.set(title=title, xlabel="Code index", ylabel="log(1 + count)")

fig.suptitle("Code usage (in log-scale)")
plt.show()