In [None]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt

import pathlib
import pickle
import numpy as np
import librosa

import IPython
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils

from tqdm import tqdm

from vqvae import VQVAE
from scheduler import CycleScheduler

from nsynth_dataset import NSynthDataset

import GANsynth_pytorch

In [None]:
# --- Runtime configuration and model loading ---------------------------------
# Prefer the first CUDA device, but fall back to the CPU so the notebook can
# still be executed on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Dataset: NSynth validation split stored under the user's home directory.
nsynth_dataset_path = pathlib.Path('~/code/data/nsynth/valid/hdf5').expanduser()
nsynth_dataset = NSynthDataset(
    root_path=nsynth_dataset_path,
    use_mel_frequency_scale=True)
num_workers = 0
# NOTE(review): `batch_size` is defined but the loader below is built with
# batch_size=1, so this value is not used in this cell. It is kept because
# other cells may read it -- confirm whether the loader should use it.
batch_size = 64
loader = DataLoader(nsynth_dataset, batch_size=1,
                    num_workers=num_workers, shuffle=True)
in_channel = 2

vqvae_decoder_activation = None

# Exactly one of these two is handed to the model as a non-None value:
# either precomputed statistics, or a dataloader passed in their place.
dataloader_for_gansynth_normalization = None
normalizer_statistics = None

use_precomputed_normalization_statistics = True
if use_precomputed_normalization_statistics:
    normalizer_statistics_path = nsynth_dataset_path / '../normalization_statistics.pkl'
    # SECURITY: pickle.load can execute arbitrary code while deserializing.
    # Only point this at a statistics file that you produced yourself.
    with open(normalizer_statistics_path, 'rb') as f:
        normalizer_statistics = pickle.load(f)
else:
    dataloader_for_gansynth_normalization = loader

print("Initializing model")
vqvae_path = pathlib.Path('../vq-vae-2-pytorch/checkpoint/vqvae_nsynth_560.pt')
vqvae = VQVAE(in_channel=in_channel,
              decoder_output_activation=vqvae_decoder_activation,
              normalizer_statistics=normalizer_statistics,
              dataloader_for_gansynth_normalization=dataloader_for_gansynth_normalization
              )
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved from a GPU run can still be loaded when only the CPU is available.
vqvae.load_state_dict(torch.load(vqvae_path, map_location=device))
# `model` is an alias for the same module object, now placed on `device`.
model = vqvae.to(device)

In [None]:
# Keep a handle on the normalizer exposed by the model.
data_normalizer = vqvae.data_normalizer

# Put the module in evaluation mode before running it.
vqvae.eval()

# A single iterator over the loader; later cells call next() on it.
iterator = iter(loader)

In [None]:
def mag_plus_phase(mag, IF):
    """Turn a (magnitude, IF) spectrogram pair back into a waveform.

    Args:
        mag: NumPy array that is exponentiated before use, so it is
            presumably a log-magnitude spectrogram -- TODO confirm.
        IF: NumPy array with the same layout as ``mag``; presumably the
            instantaneous frequency, expressed in units of pi radians per
            frame -- TODO confirm against the dataset code.

    Returns:
        The result of ``librosa.istft`` applied to the rebuilt complex STFT
        (hop length 512, window length 2048, Hann window).
    """
    # Invert the exponent, remove the 1e-6 offset, and take the absolute
    # value so the magnitude cannot be negative after the subtraction.
    magnitude = np.abs(np.exp(mag) - 1.0e-6)

    # Scale by pi and accumulate along axis 1 to obtain the phase angle.
    phase_angle = np.cumsum(IF * np.pi, axis=1)

    # Polar (magnitude, angle) -> complex STFT, then back to audio.
    stft = GANsynth_pytorch.phase_operation.polar2rect(magnitude, phase_angle)
    return librosa.istft(stft, hop_length = 512, win_length=2048, window = 'hann')

def convert_representation(sample):
    """Convert a two-channel spectrogram tensor into audio.

    Args:
        sample: torch tensor whose dimension 0 indexes the channels; index 0
            is read as the mel spectrogram and index 1 as the mel IF. It may
            live on any device and may be attached to an autograd graph.

    Returns:
        The audio produced by ``mag_plus_phase`` from the converted
        spectrogram and IF.
    """
    channel_dimension = 0
    # detach() replaces the legacy `.data` accessor: it yields a tensor that
    # is cut from the autograd graph, which is required before .numpy().
    spec = sample.select(channel_dimension, 0).detach().cpu().numpy()
    IF = sample.select(channel_dimension, 1).detach().cpu().numpy()
    back_mag, back_IF = GANsynth_pytorch.spectrograms_helper.melspecgrams_to_specgrams(spec, IF)
    # Append a copy of the last row to both arrays, adding one row along
    # axis 0. Indexing with -1 instead of the literal 1023 is identical for
    # 1024-row input and no longer hard-codes the row count.
    # NOTE(review): presumably this pads to the 1025 bins that a 2048-point
    # inverse STFT expects -- confirm.
    back_mag = np.vstack((back_mag, back_mag[-1]))
    back_IF = np.vstack((back_IF, back_IF[-1]))
    audio = mag_plus_phase(back_mag, back_IF)
    return audio

In [None]:
# Take the next item from the loader and index its first element three
# levels deep to get a single example.
# NOTE(review): the exact nesting depends on what NSynthDataset yields --
# confirm against the dataset code.
sample = next(iterator)[0][0][0]
sample_audio = convert_representation(sample)
print("Original audio")
IPython.display.display(IPython.display.Audio(sample_audio, rate=16000))

# Inference only: disable autograd so no graph is built for the forward pass.
# The module is called directly (rather than through .forward) so that any
# registered hooks run as well.
with torch.no_grad():
    reconstructed = vqvae(sample.to(device))[0][0]
reconstructed_audio = convert_representation(reconstructed)
print("Reconstructed audio")
IPython.display.display(IPython.display.Audio(reconstructed_audio, rate=16000))

In [None]:
# Show the learned codebooks of both quantizers as matrix images.

# Copy each quantizer's embedding table to host memory as a NumPy array.
codes_bottom = vqvae.quantize_b.embed.data.cpu().numpy()
codes_top = vqvae.quantize_t.embed.data.cpu().numpy()

# matshow opens a new figure each time, so every title call applies to the
# figure created immediately before it.
plt.matshow(codes_top)
plt.title('Top layer (1st, unconditioned layer) codes')

plt.matshow(codes_bottom)
plt.title('Bottom layer (2nd, conditioned layer) codes')

In [None]:
# generate random codes for synthesis