In [None]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt

import pathlib
import json
import pickle
import numpy as np
import librosa

import IPython
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils

from tqdm import tqdm

from vqvae import VQVAE, InferenceVQVAE
from scheduler import CycleScheduler

from nsynth_dataset import NSynthH5Dataset
from GANsynth_pytorch.pytorch_nsynth_lib.nsynth import (
    NSynth, make_to_mel_spec_and_IF_image_transform,
    WavToSpectrogramDataLoader)

import GANsynth_pytorch
from GANsynth_pytorch.utils import plots

In [None]:
device = 'cpu'
valid_pitch_range = [24, 84]
HOP_LENGTH = 512
N_FFT = 2048
FS_HZ = 16000
USE_MEL_FREQUENCY = True


# Dataset format to load: 'wav' or 'hdf5'.
dataset_type = 'wav'
nsynth_dataset_path = pathlib.Path('~/code/data/nsynth/valid/json_wav').expanduser()
if dataset_type == 'hdf5':
    # `NSynthH5Dataset` is the HDF5 dataset class imported at the top of the notebook.
    # NOTE(review): its constructor keywords are not visible here — confirm they match.
    nsynth_dataset = NSynthH5Dataset(
        root_path=nsynth_dataset_path,
        use_mel_frequency_scale=USE_MEL_FREQUENCY)
elif dataset_type == 'wav':
    nsynth_dataset = NSynth(
        root=str(nsynth_dataset_path),
        valid_pitch_range=valid_pitch_range,
        categorical_field_list=[],
        convert_to_float=True)
else:
    raise ValueError("Unrecognized dataset type")

# A batched loader over the whole dataset and a loader yielding one random
# example at a time (used for the listening tests below).
num_workers = 4
batch_size = 64
loader = WavToSpectrogramDataLoader(
    nsynth_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    device=device,
)
single_sample_loader = WavToSpectrogramDataLoader(
    nsynth_dataset,
    batch_size=1,
    num_workers=0,
    shuffle=True,
    device=device,
)

# Two input channels: spectrogram and instantaneous frequency
in_channel = 2

vqvae_decoder_activation = None

# Normalization inputs handed to the VQVAE constructor; both stay None when
# input normalization is disabled.
dataloader_for_gansynth_normalization = None
normalizer_statistics = None

disable_input_normalization = True
use_precomputed_normalization_statistics = True
normalization_enabled = not disable_input_normalization
if normalization_enabled and use_precomputed_normalization_statistics:
    normalizer_statistics_path = nsynth_dataset_path / '../normalization_statistics.pkl'
    # pickle.load can execute arbitrary code: only load trusted statistics files.
    with normalizer_statistics_path.open('rb') as statistics_file:
        normalizer_statistics = pickle.load(statistics_file)
elif normalization_enabled:
    # Presumably lets the model estimate the statistics from the data — TODO confirm in VQVAE.
    dataloader_for_gansynth_normalization = loader

print("Initializing model")
run_ID = '20191022-155506-0839e0'
checkpoint_epoch = '132'
# All artifacts of one training run live in a single checkpoint directory.
checkpoint_dir = pathlib.Path(f'../vq-vae-2-pytorch/checkpoints/{run_ID}')
vqvae_path = checkpoint_dir / f'vqvae_nsynth_{checkpoint_epoch}.pt'
vqvae_parameters_path = checkpoint_dir / 'model_parameters.json'
vqvae_parameters = json.loads(vqvae_parameters_path.read_text())

vqvae = VQVAE(
    decoder_output_activation=vqvae_decoder_activation,
    dataloader_for_gansynth_normalization=dataloader_for_gansynth_normalization,
    normalizer_statistics=normalizer_statistics,
    **vqvae_parameters,
)
# map_location remaps stored tensors onto `device` (e.g. a GPU checkpoint on CPU).
# torch.load unpickles the file: only load trusted checkpoints.
vqvae.load_state_dict(torch.load(vqvae_path, map_location=device))
vqvae.to(device)
vqvae.eval()

inference_vqvae = InferenceVQVAE(vqvae, device, hop_length=HOP_LENGTH, n_fft=N_FFT)

def make_audio(mag_and_IF_batch: torch.Tensor) -> np.ndarray:
    """Convert a batch of magnitude/IF representations to one mono waveform.

    The conversion is done by ``inference_vqvae.mag_and_IF_to_audio`` (mel
    scale controlled by ``USE_MEL_FREQUENCY``). The resulting tensor is
    flattened into a 1-D NumPy array, so the examples of a batch end up
    concatenated (assuming a batch-first audio tensor — TODO confirm).
    """
    waveforms = inference_vqvae.mag_and_IF_to_audio(
        mag_and_IF_batch, use_mel_frequency=USE_MEL_FREQUENCY)
    return waveforms.flatten().cpu().numpy()

def make_audio_player(mag_and_IF_batch: torch.Tensor) -> None:
    """Display an inline, volume-normalized audio player for a mag/IF batch.

    The audio comes from ``make_audio`` and is played at ``FS_HZ``.
    """
    player = IPython.display.Audio(make_audio(mag_and_IF_batch),
                                   rate=FS_HZ, normalize=True)
    IPython.display.display(player)
    
def plot_specs_and_IFs(*specs_and_IFs) -> None:
    """Plot spectrogram / instantaneous-frequency pairs on a grid of subplots.

    Each positional argument is one tensor. Its index 0 is drawn as the
    spectrogram and index 1 as the IF, via ``plots.plot_mel_representations``.
    Each argument occupies two adjacent subplots. At most ``plots_per_row``
    subplots are placed per row, and further pairs wrap onto new rows.

    For convenience, a single list or tuple of such tensors is also accepted
    and treated as if its elements had been passed individually.
    """
    if len(specs_and_IFs) == 1 and isinstance(specs_and_IFs[0], (list, tuple)):
        specs_and_IFs = tuple(specs_and_IFs[0])

    plots_per_row = 12  # even, so a spectrogram and its IF always share a row
    num_subplots = 2 * len(specs_and_IFs)
    num_columns = max(1, min(num_subplots, plots_per_row))
    num_rows = max(1, -(-num_subplots // num_columns))  # ceiling division
    # squeeze=False keeps `axes` 2-D so that flattening works for any grid size
    _, axes = plt.subplots(num_rows, num_columns,
                           figsize=(25, 10 * num_rows), squeeze=False)
    flat_axes = axes.ravel()
    for pair_index, spec_and_IF in enumerate(specs_and_IFs):
        spec_and_IF = spec_and_IF.detach().cpu().numpy()
        plots.plot_mel_representations(spec_and_IF[0], spec_and_IF[1],
                                       hop_length=HOP_LENGTH, fs_hz=FS_HZ,
                                       ax_spec=flat_axes[2 * pair_index],
                                       ax_IF=flat_axes[2 * pair_index + 1])
    # Hide the grid cells left empty at the end of the last row
    for unused_ax in flat_axes[num_subplots:]:
        unused_ax.axis('off')
    plt.tight_layout()

In [None]:
samples, reconstructions = inference_vqvae.sample_reconstructions(single_sample_loader)

# using InferenceVQVAE object (same mel setting as the rest of the notebook)
sample_audio = inference_vqvae.mag_and_IF_to_audio(samples,
                                                   use_mel_frequency=USE_MEL_FREQUENCY)
print("Original audio")
def print_stats(melspec_and_IF_batch: torch.Tensor) -> None:
    """Print the mean and variance of each channel of a mag/IF batch.

    Channel 0 (indexed on dim 1) is reported as the log-mel spectrogram and
    channel 1 as the mel instantaneous frequency. The variance is torch's
    default (unbiased) estimate.
    """
    channel_labels = ('Log-mel-spec', 'Mel-IF')
    for channel_index, label in enumerate(channel_labels):
        channel_values = melspec_and_IF_batch[:, channel_index]
        mean_value = channel_values.mean().item()
        variance_value = channel_values.var().item()
        print(f'{label}: Mean {mean_value:.2f}, variance {variance_value:.2f}')

print_stats(samples)
IPython.display.display(IPython.display.Audio(sample_audio.flatten().cpu().numpy(), rate=FS_HZ))

print("Reconstructed audio")
print_stats(reconstructions)
# using shortcut helper function in the rest of this notebook
make_audio_player(reconstructions)

In [None]:
# Reconstruction without using all hierarchical latent layers: decode with the
# top latent map zeroed, with the bottom map zeroed, and with both maps.
with torch.no_grad():
    samples, _ = next(iter(single_sample_loader))
    quant_t, quant_b, *_ = vqvae.encode(samples.to(device))
    decoded_only_bottom = vqvae.decode(torch.zeros_like(quant_t), quant_b).cpu()
    decoded_only_top = vqvae.decode(quant_t, torch.zeros_like(quant_b)).cpu()
    decoded = vqvae.decode(quant_t, quant_b).cpu()

# Order: original, only bottom, only top, top and bottom
specs_and_IFs = [samples[0], decoded_only_bottom[0], decoded_only_top[0], decoded[0]]
# Unpack: the plotting helper takes one spectrogram/IF tensor per positional argument
plot_specs_and_IFs(*specs_and_IFs)
# plt.suptitle('Original, Reconstructed using only bottom / only top / top and bottom', pad=20)

In [None]:
# Display the learned codebooks. For each quantization layer: the codebook with
# the mean code subtracted (left) and a histogram of the pairwise correlation
# distances between codes (right).
from scipy.spatial.distance import pdist, squareform

named_code_layers = {'top': vqvae.quantize_t, 'bottom': vqvae.quantize_b}
num_layers = len(named_code_layers)
# squeeze=False keeps `axes` 2-D even when there is a single layer
fig, axes = plt.subplots(num_layers, 2, figsize=(20, 5*num_layers), squeeze=False)

for layer_index, (layer_name, code_layer) in enumerate(named_code_layers.items()):
    ax_codes, ax_distances = axes[layer_index]
    # NOTE(review): assumes `embed` is stored as (embedding_dim, num_codes), so the
    # transpose gives one code per row — confirm against the Quantize module.
    codes = code_layer.embed.data.cpu().T
    codes_matrix = (codes - codes.mean(dim=0, keepdim=True)).numpy()
    ax_codes.matshow(codes_matrix)
    ax_codes.set_title(f'{layer_name} layer codes (layer number {layer_index+1})', pad=20)

    # pdist's 'correlation' metric is 1 - Pearson correlation between two rows.
    # It is undefined (NaN) for constant rows; drop those before histogramming.
    pairwise_correlation_distances = pdist(codes_matrix, 'correlation')
    pairwise_correlation_distances = pairwise_correlation_distances[
        np.isfinite(pairwise_correlation_distances)]
    ax_distances.hist(pairwise_correlation_distances, bins=30)
    ax_distances.set_xlabel('correlation distance (1 - r)')
    ax_distances.set_ylabel('number of code pairs')
    ax_distances.set_title(
        f'{str(layer_name).capitalize()} layer codes pairwise correlation distance \n (layer number {layer_index+1})',
        pad=20)

plt.tight_layout()
plt.show()

In [None]:
# Code usage over the dataset iterated by `loader` (the NSynth split loaded above)
num_embeddings_top = vqvae.quantize_t.n_embed
num_embeddings_bottom = vqvae.quantize_b.n_embed
# One int64 counter per code index, accumulated in place by `count_uses`
code_usage_top = torch.zeros(num_embeddings_top, dtype=torch.long, device=device)
code_usage_bottom = torch.zeros(num_embeddings_bottom, dtype=torch.long, device=device)

def count_uses(codes_batch, code_usage_matrix):
    """Add, in place, the number of occurrences of each code index.

    Args:
        codes_batch: integer tensor of code indices (any shape; it is flattened).
        code_usage_matrix: 1-D integer tensor of per-code counts. It is updated
            in place and its length sets the minimum histogram length.
    """
    histogram = torch.bincount(codes_batch.reshape(-1),
                               minlength=code_usage_matrix.shape[0])
    code_usage_matrix.add_(histogram)

with torch.no_grad():
    for batch, _pitch in tqdm(loader, desc='counting code usage'):
        _, _, _, id_t, id_b, _, _ = vqvae.encode(batch.to(device))
        count_uses(id_t, code_usage_top)
        count_uses(id_b, code_usage_bottom)

fig, (ax_bottom, ax_top) = plt.subplots(2, 1, figsize=(30, 15))
ax_bottom.bar(range(num_embeddings_bottom), height=code_usage_bottom.cpu().numpy())
ax_bottom.set_title("Bottom layer")
ax_bottom.set_xlabel('code index')
ax_bottom.set_ylabel('number of uses')
ax_top.bar(range(num_embeddings_top), height=code_usage_top.cpu().numpy())
ax_top.set_title("Top layer")
ax_top.set_xlabel('code index')
ax_top.set_ylabel('number of uses')

fig.suptitle("Code usage")
# Save after the suptitle is set so that the saved figure includes it
fig.savefig(f'code_usage-epoch_{checkpoint_epoch}.svg')
plt.show()

## Embeddings dimension comparison

In this section, we compare the dimensions of an input example with those of its latent maps and code maps,
in order to compute the compression ratio and check whether it actually makes sense.

In [None]:
one_shot_single_sample_iterator = iter(single_sample_loader)
sample_as_batch, _ = next(one_shot_single_sample_iterator)
sample = sample_as_batch[0]

with torch.no_grad():
    # encode returns seven values; only the first five are needed here
    *results_as_batches, _, _ = vqvae.encode(sample_as_batch.to(device))
    # Drop the batch dimension (batch size is 1). 0-dim results (e.g. a scalar
    # latent loss) have no batch dimension to drop.
    quant_t, quant_b, diff_t_plus_b, id_t, id_b = [
        (tensor[0] if tensor.dim() > 0 else tensor).cpu()
        for tensor in results_as_batches]

print('Original shape: ', sample.shape)
print('Quant bottom shape: ', quant_b.shape)
print('Quant top shape: ', quant_t.shape)
print('Code bottom shape: ', id_b.shape)
print('Code top shape: ', id_t.shape)
# Number of discrete codes per input value (does not account for bits per value)
print("Compression ratio: ", (id_b.numel() + id_t.numel()) / sample.numel())

# Sound creation

In this section, we attempt to create new sounds by combining the latent codes of two different sounds.

In [None]:
two_sample_loader = WavToSpectrogramDataLoader(nsynth_dataset, batch_size=2,
                                               num_workers=2, shuffle=True,
                                               device=device)
one_shot_two_sample_iterator = iter(two_sample_loader)
with torch.no_grad():
    samples, _ = next(one_shot_two_sample_iterator)
    quant_t, quant_b, _, id_t, id_b, _, _ = vqvae.encode(samples.to(device))

print("Original audios")
make_audio_player(samples)

print('Reconstructions')
with torch.no_grad():
    reconstructions = vqvae.decode_code(id_t, id_b)
make_audio_player(reconstructions)

# Flipping the batch dimension of the bottom codes pairs each sample's top
# codes with the *other* sample's bottom codes.
print('Swap bottom codes between the two samples (top codes kept)')
with torch.no_grad():
    reconstructions_switched_codes = vqvae.decode_code(id_t, id_b.flip(0))
make_audio_player(reconstructions_switched_codes)

plot_specs_and_IFs(*samples[0:2], *reconstructions[0:2], *reconstructions_switched_codes[0:2])

# print('Exchange half of the codes (in the temporal dimension)')
# def switch_temporal_halves(codes_tensor: torch.Tensor) -> torch.Tensor:
#     batch_dim, frequency_dim, temporal_dim = 0, 1, 2
#     duration_n = codes_tensor.shape[temporal_dim]
#     first_half = codes_tensor[:, :, :duration_n//2]
#     second_half = codes_tensor[:, :, duration_n//2:]
#     return torch.cat([first_half.flip(0), second_half], temporal_dim)
# with torch.no_grad():
#     reconstructions_halves_switched = vqvae.decode_code(
#         *[switch_temporal_halves(codes) for codes in [id_t, id_b]])
# make_audio_player(reconstructions_halves_switched)

# print('Set half (temporally) of the latent map to zero')
# def zero_temporal_halves(codes_tensor: torch.Tensor) -> torch.Tensor:
#     batch_dim, frequency_dim, temporal_dim = 0, 1, 2
#     duration_n = codes_tensor.shape[temporal_dim]
#     first_half = codes_tensor[:, :, :duration_n//2]
#     zero_second_half = torch.zeros(codes_tensor[:, :, duration_n//2:].shape, dtype=first_half.dtype).to(device)
#     return torch.cat([first_half, zero_second_half], dim=temporal_dim)
# with torch.no_grad():
#     reconstructions_halves_zeroed = vqvae.decode_code(
#         *[zero_temporal_halves(codes) for codes in [id_t, id_b]])
# make_audio_player(reconstructions_halves_zeroed)

In [None]:
print('Linear interpolation of the embeddings')
def make_interpolations(start_map: torch.Tensor, end_map: torch.Tensor,
                        steps: int) -> torch.Tensor:
    """Linearly interpolate between two tensors of identical shape.

    Returns a tensor of shape ``(steps, *start_map.shape)`` whose slice ``i``
    is ``start_map + t_i * (end_map - start_map)``, with ``t_i`` evenly spaced
    over [0, 1] by ``torch.linspace``. The first slice is ``start_map`` and the
    last is ``end_map`` (up to floating-point rounding). Works for maps of any
    rank; the ratios are created on the maps' device.
    """
    ratio_dtype = (start_map.dtype if start_map.is_floating_point()
                   else torch.get_default_dtype())
    ratios = torch.linspace(0, 1, steps, dtype=ratio_dtype, device=start_map.device)
    # (steps,) -> (steps, 1, ..., 1) so the ratios broadcast against the maps
    ratios = ratios.reshape(steps, *([1] * start_map.dim()))
    return start_map + ratios * (end_map - start_map)

with torch.no_grad():
    num_steps = 5
    # Interpolate between the latent maps of the two samples of the batch
    interpolations_unquant_t_and_b = [
        make_interpolations(codes[0], codes[1], num_steps)
        for codes in (quant_t, quant_b)]
    print('Interpolated latent map shapes (top, bottom):',
          tuple(interpolations_unquant_t_and_b[0].shape),
          tuple(interpolations_unquant_t_and_b[1].shape))

    # Snap each interpolated vector back onto its codebook. Dim 1 is moved last
    # for the quantizer call and moved back afterwards.
    interpolations_quant_t, interpolations_quant_b = [
        quantizer(interpolations_unquant.permute(0, 2, 3, 1))[0].permute(0, 3, 1, 2)
        for quantizer, interpolations_unquant in zip(
            [vqvae.quantize_t, vqvae.quantize_b],
            interpolations_unquant_t_and_b)
    ]
    interpolations = vqvae.decode(interpolations_quant_t, interpolations_quant_b)

for index, interpolation in enumerate(interpolations):
    print(f'Step {index}')
    make_audio_player(interpolation)

# Baseline: average of the two reconstructed signals in the time domain
print('Time-domain average of the two reconstructions')
with torch.no_grad():
    reconstructed_audio = inference_vqvae.mag_and_IF_to_audio(
        reconstructions, use_mel_frequency=USE_MEL_FREQUENCY)
time_domain_mix = (reconstructed_audio.sum(0) / 2).flatten().cpu().numpy()
IPython.display.display(IPython.display.Audio(time_domain_mix, rate=FS_HZ, normalize=True))

specs_and_IFs = [samples[0], samples[1], interpolations[1], interpolations[2], interpolations[3], samples.sum(0)]
plot_specs_and_IFs(*specs_and_IFs)

In [None]:
# WARNING: don't feed latent maps to the quantizer by reshaping them, use .permute().
# Otherwise, the results are scrambled! This cell demonstrates both ways.
with torch.no_grad():
    # One example's top latent map, assumed channels-first (C, H, W) — TODO confirm
    latent_map_t = quant_t[0]
    num_channels, height, width = latent_map_t.shape
    # Wrong: reshape reinterprets the memory layout, so each vector handed to
    # the quantizer mixes values from different channels and positions.
    embeddings_t = latent_map_t.reshape(height, width, num_channels)
    requantized_embeddings_t = vqvae.quantize_t(embeddings_t)[0]
    # Right: permute moves the channel axis last, keeping each position's vector intact.
    permuted_embeddings_t = latent_map_t.permute(1, 2, 0)
    requantized_permuted_embeddings_t = vqvae.quantize_t(permuted_embeddings_t)[0]

print('Max |requantized - input| with reshape:',
      (requantized_embeddings_t - embeddings_t).abs().max().item())
print('Max |requantized - input| with permute:',
      (requantized_permuted_embeddings_t - permuted_embeddings_t).abs().max().item())

### Latent-codes distortion

Here, we slightly modify the code maps (by shifting code indices) to check the effect on the decoded audio.

In [None]:
sample, reconstruction = inference_vqvae.sample_reconstructions(single_sample_loader)

with torch.no_grad():
    # encode returns seven values; only the first five are needed here
    *results_as_batches, _, _ = vqvae.encode(sample.to(device))
    quant_t, quant_b, diff_t_plus_b, id_t, id_b = [
        tensor.cpu() for tensor in results_as_batches]

all_code_layers = {'top': id_t, 'bottom': id_b}
code_layers_to_corrupt_names = ['bottom']
# A comprehension (rather than a set difference) keeps the order deterministic
code_layers_to_keep_names = [name for name in all_code_layers
                             if name not in code_layers_to_corrupt_names]
num_embeddings = {'top': vqvae.quantize_t.n_embed,
                  'bottom': vqvae.quantize_b.n_embed}

# Each value is added uniformly to every code index of the corrupted layers.
# torch.remainder wraps negative results into [0, n_embed). The shifted
# versions are concatenated along dim 0, one per value.
add_values = [+1, -1, +10, -10, +20, -20, +100, -100, +200, -200]
id_plus_corrupted = {
    code_layer_name: torch.cat([all_code_layers[code_layer_name]
                                .add(add_value)
                                .remainder(num_embeddings[code_layer_name])
                                .long()
                                for add_value in add_values], 0)
    for code_layer_name in code_layers_to_corrupt_names
}

# Kept layers are tiled along dim 0 to match, whatever the rank of the code maps
id_plus_original = {
    code_layer_name: all_code_layers[code_layer_name].repeat(
        len(add_values), *[1] * (all_code_layers[code_layer_name].dim() - 1))
    for code_layer_name in code_layers_to_keep_names
}

def pick_codes(original, corrupted):
    """Assemble the per-layer code maps to feed to ``vqvae.decode_code``.

    For each layer name of ``all_code_layers`` (in that dict's order), the map
    from ``original`` is used when present and not None. Otherwise the one from
    ``corrupted`` is used. Every chosen map is moved to ``device``.

    Returns:
        list of tensors, one per layer, in ``all_code_layers`` order.
    Raises:
        KeyError: if a layer is missing (or None) in ``original`` and missing
            from ``corrupted``.
    """
    picked_codes = []
    for code_layer_name in all_code_layers:
        layer_codes = original.get(code_layer_name)
        if layer_codes is None:
            layer_codes = corrupted[code_layer_name]
        picked_codes.append(layer_codes.to(device))
    return picked_codes

# weights = torch.Tensor([0.2, 0.6, 0.2])
# id_stochastic_1 = [
#     id_map.add(torch.multinomial(weights, id_map.numel(), replacement=True).reshape(id_map.shape).add(-1).to(device)).remainder(n_embed).long()
#     for id_map, n_embed in zip([id_t, id_b],
#                                [vqvae.quantize_t.n_embed, vqvae.quantize_b.n_embed])
# ]

# id_stochastic_randn_0_20_corrupted = [
#     id_map.add((torch.randn_like(id_map, dtype=float).to(device) * 20).long()).remainder(n_embed).long()
#     for id_map, n_embed in zip(codes_to_corrupt,
#                                num_embeddings_to_corrupt)
# ]

# print(id_stochastic_randn_0_20_corrupted)
# print(codes_to_keep)
# id_stochastic = id_stochastic_randn_0_20_corrupted + codes_to_keep
# print(id_stochastic)

with torch.no_grad():
    # Arguments in the order pick_codes names them: kept layers first, corrupted second
    reconstructions_plus = vqvae.decode_code(*pick_codes(id_plus_original, id_plus_corrupted))
#     reconstructions_stochastic_1 = vqvae.decode_code(*id_stochastic_1)
#     reconstructions_stochastic_randn_0_20 = vqvae.decode_code(*id_stochastic)

print('Original')
make_audio_player(sample)

print('Reconstructions')
make_audio_player(reconstruction)

# One decoded example per shift value, in the order of `add_values`
for index, add_value in enumerate(add_values):
    print('Uniform add', add_value)
    make_audio_player(reconstructions_plus[index])

# make_audio_player(reconstructions_stochastic_1)

# make_audio_player(reconstructions_stochastic_randn_0_20)

In [None]:
# Stochastic corruption of both code layers: every code index is shifted by -1,
# 0 or +1 with probabilities 0.2 / 0.6 / 0.2, wrapping modulo the codebook size.
# NOTE(review): reconstructed from the commented-out version in the previous
# cell; using multinomial samples directly as codes would only ever pick 0-2.
STOCHASTIC_SEED = 0
generator = torch.Generator().manual_seed(STOCHASTIC_SEED)
weights = torch.Tensor([0.2, 0.6, 0.2])
id_stochastic = []
for id_map, n_embed in zip([id_t, id_b],
                           [vqvae.quantize_t.n_embed, vqvae.quantize_b.n_embed]):
    # multinomial draws indices in {0, 1, 2}; subtracting 1 maps them to shifts {-1, 0, +1}
    shifts = torch.multinomial(weights, id_map.numel(), replacement=True,
                               generator=generator).reshape(id_map.shape) - 1
    id_stochastic.append((id_map.cpu() + shifts).remainder(n_embed).long().to(device))

print('Stochastic -1/0/+1 code shifts')
with torch.no_grad():
    reconstructions_stochastic = vqvae.decode_code(*id_stochastic)
make_audio_player(reconstructions_stochastic)

# Out-of-domain application