In [None]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt

import pathlib
import json
import pickle
import numpy as np
import librosa
import math

import IPython
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils

from tqdm.notebook import tqdm as tqdm_notebook

from vqvae.vqvae import VQVAE
from priors.transformer import VQNSynthTransformer

from pytorch_nsynth import NSynth
from GANsynth_pytorch.loader import (WavToSpectrogramDataLoader,
                                     MaskedPhaseWavToSpectrogramDataLoader)
from utils.misc import get_spectrograms_helper

import GANsynth_pytorch
from GANsynth_pytorch.utils import plots

In [None]:
# Use the GPU when available; fall back to the CPU so the notebook can still run.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print("Initializing model")
# VQ-VAE training run and checkpoint (epoch) to analyse.
vqvae_run_ID = 'VQVAE-20200914-120547-a8f38a'
checkpoint_epoch = '19901'
vqvae_folder = pathlib.Path(f'./data/checkpoints/{vqvae_run_ID}/')
vqvae_weights_path = vqvae_folder / f'vqvae_nsynth_{checkpoint_epoch}.pt'
vqvae_model_parameters_path = vqvae_folder / 'model_parameters.json'
vqvae_training_parameters_path = vqvae_folder / 'command_line_parameters.json'

# `vqvae_training_parameters` (the run's command-line options) configures the
# dataset and spectrogram helper below; `vqvae_model_parameters` is only loaded
# for inspection in the visible cells.
with open(vqvae_model_parameters_path, 'r') as f:
    vqvae_model_parameters = json.load(f)
with open(vqvae_training_parameters_path, 'r') as f:
    vqvae_training_parameters = json.load(f)

vqvae = VQVAE.from_parameters_and_weights(
    vqvae_model_parameters_path,
    vqvae_weights_path,
    device=DEVICE
)
vqvae.to(DEVICE)
vqvae.eval()

In [None]:
import os

# NOTE(review): not used in the visible cells — the dataset below uses the
# pitch range the VQ-VAE was trained with.
valid_pitch_range = [24, 84]

# Root of the local NSynth copy. Set the NSYNTH_DATA_ROOT environment variable
# to override this machine-specific default rather than editing the notebook.
NSYNTH_DATA_ROOT = pathlib.Path(
    os.environ.get('NSYNTH_DATA_ROOT', '/home/theis/code/data'))

dataset_audio_directory_paths = [
    str(NSYNTH_DATA_ROOT / 'nsynth' / 'train' / 'audio'),
    str(NSYNTH_DATA_ROOT / 'nsynth' / 'valid' / 'audio'),
]
validation_dataset_json_data_path = str(
    NSYNTH_DATA_ROOT / 'nsynth-balanced-split-fixed_seed' / 'valid' / 'examples.json')

# Use the same pitch range as during the VQ-VAE's training; no categorical
# fields are requested.
common_dataset_parameters = {
    'valid_pitch_range': vqvae_training_parameters['valid_pitch_range'],
    'categorical_field_list': [],
    'squeeze_mono_channel': True
}
nsynth_validation_dataset = NSynth(
    audio_directory_paths=dataset_audio_directory_paths,
    json_data_path=validation_dataset_json_data_path,
    return_full_metadata=False,
    **common_dataset_parameters)

In [None]:
# The masked-phase loader is used when the VQ-VAE run was trained with an
# output spectrogram threshold, the plain one otherwise.
dataloader_class = (
    MaskedPhaseWavToSpectrogramDataLoader
    if vqvae_training_parameters['output_spectrogram_threshold']
    else WavToSpectrogramDataLoader
)

# Spectrogram conversion helper configured with the run's training options.
spectrograms_helper = get_spectrograms_helper(DEVICE, **vqvae_training_parameters)

BATCH_SIZE = 2

validation_loader = dataloader_class(
    dataset=nsynth_validation_dataset,
    spectrograms_helper=spectrograms_helper,
    batch_size=BATCH_SIZE,
    num_workers=1,
    shuffle=True,
    pin_memory=False,
)

def make_audio(mag_and_IF_batch: torch.Tensor) -> np.ndarray:
    """Convert a batch of (magnitude, IF) spectrograms to a single NumPy waveform.

    The batch is inverted with `spectrograms_helper.to_audio` and the resulting
    audio tensor is flattened, moved to the CPU and returned as an array
    (presumably the batch elements end up concatenated one after the other —
    this depends on the layout returned by `to_audio`).
    """
    waveforms = spectrograms_helper.to_audio(mag_and_IF_batch)
    return waveforms.flatten().cpu().numpy()

def make_audio_player(mag_and_IF_batch: torch.Tensor) -> None:
    """Display an inline audio player for a batch of (magnitude, IF) spectrograms.

    The waveform comes from `make_audio` and is played at the helper's sampling
    rate; IPython rescales it to full range (`normalize=True`).
    """
    player = IPython.display.Audio(
        make_audio(mag_and_IF_batch),
        rate=spectrograms_helper.fs_hz,
        normalize=True,
    )
    IPython.display.display(player)
    
def plot_specs_and_IFs(*specs_and_IFs, plots_per_row: int = 12) -> None:
    """Plot the mel spectrogram and instantaneous frequency of each argument.

    Each argument is a tensor whose index 0 is the (log-)mel magnitude and
    index 1 the mel instantaneous frequency. Each argument takes two
    consecutive subplots (magnitude, then IF) of a single grid that wraps
    after `plots_per_row` subplots.

    Arguments:
        *specs_and_IFs (torch.Tensor): the (magnitude, IF) tensors to draw.
        plots_per_row (int): maximum number of subplots on one row.
    """
    num_subplots = 2 * len(specs_and_IFs)
    num_columns = max(1, min(plots_per_row, num_subplots))
    num_rows = max(1, math.ceil(num_subplots / num_columns))
    # Draw into the axes of this single grid so that no extra axes are stacked
    # on top of it.
    fig, axes = plt.subplots(num_rows, num_columns,
                             figsize=(25, 10*num_rows), squeeze=False)
    flat_axes = axes.flatten()
    for subplot_index, spec_and_IF in enumerate(specs_and_IFs):
        spec_and_IF = spec_and_IF.detach().cpu().numpy()
        plots.plot_mel_representations(
            spec_and_IF[0], spec_and_IF[1],
            hop_length=spectrograms_helper.hop_length,
            fs_hz=spectrograms_helper.fs_hz,
            ax_spec=flat_axes[2*subplot_index],
            ax_IF=flat_axes[2*subplot_index + 1])
    # Hide the trailing cells of an incomplete last row.
    for unused_ax in flat_axes[num_subplots:]:
        unused_ax.set_visible(False)
    plt.tight_layout()

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt



def full_frame(width=None, height=None):
    """Create a figure whose single axes fill the canvas, with no frame or axes.

    Side effect: sets the global matplotlib rcParam `savefig.pad_inches` to 0,
    which affects every figure saved afterwards.

    Adapted from a GitHub Gist by Kyle McDonald:
    https://gist.github.com/kylemcdonald/bedcc053db0e7843ef95c531957cb90f

    Arguments:
        width, height (float, optional): figure size in inches; matplotlib's
            default size is used when `width` is None.
    Returns:
        (fig, ax): the new figure and its full-frame axes.
    """
    import matplotlib as mpl
    mpl.rcParams['savefig.pad_inches'] = 0
    if width is None:
        fig = plt.figure(figsize=None)
    else:
        fig = plt.figure(figsize=(width, height))
    ax = plt.axes([0, 0, 1, 1], frameon=False)
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)
    plt.autoscale(tight=True)
    return fig, ax


def show_values(pc, fmt="%.2f", **kw):
    '''
    Write each cell's value as centred text on a `pcolor` heatmap.

    Source: http://stackoverflow.com/a/25074150/395857
    By HYRY

    Arguments:
        pc: the collection returned by `ax.pcolor(...)`.
        fmt (str): %-style format applied to each cell value.
        **kw: extra keyword arguments forwarded to `ax.text`.
    '''
    pc.update_scalarmappable()
    # `Artist.get_axes()` was removed from matplotlib; use the `axes` attribute.
    ax = pc.axes
    for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
        # Cell centre; as in the original recipe, this assumes each cell path
        # ends with two closing vertices that must be excluded from the mean.
        x, y = p.vertices[:-2, :].mean(0)
        # Black text on light cells, white text on dark cells.
        if np.all(color[:3] > 0.5):
            color = (0.0, 0.0, 0.0)
        else:
            color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)

def cm2inch(*tupl):
    '''
    Convert lengths in centimetres to inches, e.g. for a matplotlib `figsize`.

    Accepts either separate numbers, `cm2inch(10, 20)`, or a tuple as the
    first argument, `cm2inch((10, 20))`, and returns a tuple of inches.

    Source: http://stackoverflow.com/a/22787457/395857
    By gns-ank
    '''
    centimetres_per_inch = 2.54
    values = tupl[0] if type(tupl[0]) is tuple else tupl
    return tuple(value / centimetres_per_inch for value in values)

def heatmap(AUC, title=None, xlabel='Time steps', ylabel='Frequency', xticklabels=None, yticklabels=None, cmap='YlOrRd',
           height_in=3, width_in=6, plot_colorbar=False, vmin=0.0, vmax=512.0):
    '''
    Draw a 2D tensor (e.g. a map of code indices) as a frameless heatmap.

    Each cell is outlined with thin dashed black edges.

    Inspired by:
    - http://stackoverflow.com/a/16124677/395857
    - http://stackoverflow.com/a/25074150/395857

    Arguments:
        AUC (torch.Tensor): 2D tensor to draw; its first axis is drawn along y.
        title (str, optional): title; `xlabel`/`ylabel` are applied only
            together with it.
        xticklabels, yticklabels (optional): labels placed at cell centres.
        cmap (str): matplotlib colormap name.
        height_in, width_in (float): figure size in inches.
        plot_colorbar (bool): whether to add a colorbar.
        vmin, vmax (float): limits of the colour scale. The default upper
            bound of 512 presumably corresponds to the codebook size — pass the
            quantizer's `n_embed` for other models.
    Returns:
        (fig, ax): the created figure and its single axes.
    '''
    AUC = AUC.detach().cpu().numpy()
    fig = plt.figure(frameon=False)
    fig.set_size_inches(width_in, height_in)
    # Axes spanning the whole figure, with all axis decorations turned off.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    c = ax.pcolor(AUC, edgecolors='k', linestyle='dashed', linewidths=0.2,
                  cmap=cmap, vmin=vmin, vmax=vmax)

    # Put the major ticks at the middle of each cell and set their labels.
    if xticklabels is not None:
        ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)
        ax.set_xticklabels(xticklabels, minor=False)
    if yticklabels is not None:
        ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
        ax.set_yticklabels(yticklabels, minor=False)

    if title is not None:
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    # Remove last blank column.
    ax.set_xlim((0, AUC.shape[1]))

    # Hide the tick marks (the `tick1On`/`tick2On` attributes no longer have
    # any effect in recent matplotlib versions).
    ax.tick_params(which='major', length=0)

    if plot_colorbar:
        fig.colorbar(c, ax=ax)

    return fig, ax

def make_spectrogram_image(spectrogram: torch.Tensor,
                           filename: str = 'spectrogram',
                           upsampling_factor: int = 1,
                           ) -> str:
    """Generate and save a PDF image of the provided spectrogram.

    Assumes melscale frequency axis.

    Arguments:
        spectrogram (torch.Tensor): 2D mel-scale spectrogram to draw.
        filename (str or pathlib.Path): output path without its extension.
        upsampling_factor (int): bilinear upsampling factor applied to both
            axes before drawing; the sample rate passed to librosa is
            multiplied by the same factor.
    Returns:
        output_path (str): the path where the image was written
    """
    # Depending on the librosa version, `import librosa` alone may not load
    # the `display` submodule.
    import librosa.display

    fig, ax = full_frame(width=12, height=8)
    # `interpolate` expects (batch, channel, H, W): add, then drop, two dims.
    upsampled_spectrogram = torch.nn.functional.interpolate(
        spectrogram.unsqueeze(0).unsqueeze(1),
        mode='bilinear',
        scale_factor=upsampling_factor).squeeze(0).squeeze(0)
    spectrogram_np = upsampled_spectrogram.detach().cpu().numpy()
    librosa.display.specshow(spectrogram_np,
                             ax=ax,
                             sr=spectrograms_helper.fs_hz * upsampling_factor,
                             cmap='viridis',
                             hop_length=spectrograms_helper.hop_length)

    image_format = 'pdf'
    # f-string formatting also accepts a pathlib.Path `filename`.
    output_path = f'{filename}.{image_format}'
    fig.savefig(output_path, format=image_format, dpi=200,
                pad_inches=0, bbox_inches=0)
    fig.clear()
    plt.close(fig)
    return output_path

In [None]:
@torch.no_grad()
def sample_reconstructions(loader=validation_loader):
    """Draw one batch from `loader` and reconstruct it with the VQ-VAE.

    Arguments:
        loader: iterable yielding `(spectrograms, annotations, ...)` batches;
            defaults to the validation loader.
    Returns:
        (samples, reconstructions, id_top, id_bottom, annotations): the input
        batch moved to DEVICE, its reconstruction, the top and bottom code
        maps, and the batch annotations as yielded by the loader (presumably
        a metadata dict keyed by field name, e.g. 'note_str').
    """
    vqvae.eval()
    samples, annotations, *_ = next(iter(loader))
    # `Tensor.to` is not in-place: keep the moved tensor.
    samples = samples.to(DEVICE)
    reconstructions, *_, id_top, id_bottom = vqvae.forward(samples)
    return samples, reconstructions, id_top, id_bottom, annotations

# Draw a random validation batch and its reconstruction for the checks below.
samples, reconstructions, *_ = sample_reconstructions()

def print_stats(melspec_and_IF_batch: torch.Tensor) -> None:
    """Print the mean and (unbiased) variance of both channels of a batch.

    Channel 0 is reported as the log-mel spectrogram and channel 1 as the mel
    instantaneous frequency.
    """
    channel_names = ('Log-mel-spec', 'Mel-IF')
    for channel_index, channel_name in enumerate(channel_names):
        channel = melspec_and_IF_batch[:, channel_index]
        mean = channel.mean().item()
        variance = channel.var().item()
        print(f'{channel_name}: Mean {mean:.2f}, variance {variance:.2f}')

# Compare channel statistics of the original and reconstructed batch, then
# listen to both.
print("Original audio")
print_stats(samples)
make_audio_player(samples)

print("Reconstructed audio")
print_stats(reconstructions)
# `make_audio_player` is used as a shortcut for playback in the rest of this notebook
make_audio_player(reconstructions)

In [None]:
# Loader with larger batches (64 examples), also reused further below.
large_loader = dataloader_class(
    dataset=nsynth_validation_dataset,
    spectrograms_helper=spectrograms_helper,
    batch_size=64,
    num_workers=10, shuffle=True,
    pin_memory=False)
large_batch, *_ = sample_reconstructions(large_loader)
# Sanity checks of the VQ-VAE's data normalizer:
# - relative error of a normalize -> denormalize round trip (expected to be
#   close to 0 if the normalizer is invertible),
# - the normalized value of the constant 0.5.
print((large_batch - vqvae.data_normalizer.denormalize(vqvae.data_normalizer.normalize(large_batch))).norm()/large_batch.norm())
print((vqvae.data_normalizer.normalize(0.5)))

In [None]:
from train_vqvae import get_reconstruction_criterion

jukebox_criterion = get_reconstruction_criterion('Jukebox',
                                                spectrograms_helper)

# Statistics of the criterion's squared magnitudes for a fresh batch and its
# reconstruction, followed by the resulting Jukebox loss.
samples, reconstructions, *_ = sample_reconstructions()
for set_name, t in {'Samples': samples, 'Reconstructions': reconstructions}.items():
    # Compute the (potentially expensive) squared magnitude once per tensor
    # rather than once per statistic.
    squared_magnitude = jukebox_criterion.squared_magnitude(t)
    print(set_name)
    print("Mean: ", torch.mean(squared_magnitude))
    print("Variance: ", torch.var(squared_magnitude))
    print("Min: ", torch.min(squared_magnitude))
    print("Max: ", torch.max(squared_magnitude))
    print('\n======\n')

print("Jukebox loss: ", jukebox_criterion(samples, reconstructions))

In [None]:
def get_stfts():
    """Multi-resolution STFT magnitudes of a random batch and of its reconstruction.

    One validation batch is reconstructed, and both the original and the
    reconstruction are converted to audio. Then, for each (n_fft, window
    length, window) configuration of `jukebox_criterion`, the magnitude of the
    non-centred STFT is computed for both signals, with hop size
    ceil((1 - overlap_ratio) * window_length).

    Returns:
        (spec_mag_samples_list, spec_mag_reconstructions_list): lists holding
        one magnitude tensor per STFT configuration.
    """
    samples, reconstructions, *_ = sample_reconstructions()
    audio_pair = (spectrograms_helper.to_audio(samples),
                  spectrograms_helper.to_audio(reconstructions))

    magnitudes_per_scale = []
    stft_configurations = zip(jukebox_criterion.n_ffts,
                              jukebox_criterion.window_lengths,
                              jukebox_criterion.windows)
    for n_fft, window_length, window in stft_configurations:
        print(n_fft, window_length)
        hop_length = math.ceil((1-jukebox_criterion.overlap_ratio) * window_length)
        sample_magnitude, reconstruction_magnitude = [
            jukebox_criterion.magnitude(torch.stft(
                audio, n_fft=n_fft, hop_length=hop_length,
                win_length=window_length, window=window, center=False))
            for audio in audio_pair]
        magnitudes_per_scale.append((sample_magnitude, reconstruction_magnitude))

    spec_mag_samples_list = [pair[0] for pair in magnitudes_per_scale]
    spec_mag_reconstructions_list = [pair[1] for pair in magnitudes_per_scale]
    return spec_mag_samples_list, spec_mag_reconstructions_list


spec_mag_samples_list, spec_mag_reconstructions_list = get_stfts()

In [None]:
# L1 distance between original and reconstructed STFT magnitudes for one STFT
# configuration of `jukebox_criterion` (shown as the cell output).
scale_index = 2
nn.L1Loss()(spec_mag_samples_list[scale_index], spec_mag_reconstructions_list[scale_index])

# Dump audio samples of reconstructions

In [None]:
# Examples whose reconstructions are presented in the paper, keyed by note_str.
with open('./paper/reconstructions.json', 'r') as f:
    reconstructions_details = json.load(f)

reconstructions_note_strs = list(reconstructions_details.keys())
# Map each `note_str` to its index in the validation dataset.
# NOTE(review): this indexes every dataset item, which presumably also loads
# its audio; reading the metadata directly would be much cheaper — check
# whether the NSynth dataset exposes it.
lookup_table = {nsynth_validation_dataset[i][1]['note_str']: i
                for i in range(len(nsynth_validation_dataset))}

In [None]:
# Dataset items for the selected examples, and a loader over them.
# NOTE(review): `sample_reconstructions` only consumes one batch of this loader,
# i.e. BATCH_SIZE examples picked at random (shuffle=True) — increase the batch
# size to export all selected examples at once.
reconstructions_samples = [
    nsynth_validation_dataset[lookup_table[note_str]]
    for note_str in reconstructions_note_strs]
reconstructions_loader = dataloader_class(
    dataset=reconstructions_samples,
    spectrograms_helper=spectrograms_helper,
    batch_size=BATCH_SIZE,
    num_workers=8, shuffle=True,
    pin_memory=True)

In [None]:
# Display the first selected dataset item.
reconstructions_samples[0]

In [None]:
import soundfile as sf
import json

samples_dir = pathlib.Path('./paper/reconstructions')
audio_output_dir = samples_dir / 'audio'
audio_output_dir.mkdir(parents=True, exist_ok=True)

NORMALIZE_AUDIO = True


def peak_normalize(audio: np.ndarray) -> np.ndarray:
    """Scale `audio` so that its largest absolute sample is 1 (silence is returned unchanged)."""
    peak = np.abs(audio).max()
    return audio / peak if peak > 0 else audio


# Export original and reconstructed audio for one batch of the selected
# examples, together with their pitch and instrument family.
samples, reconstructions, id_top, id_bottom, annotations = sample_reconstructions(reconstructions_loader)
reconstructions_details = {}
for sample, reconstruction, note_str, pitch, instrument_family_str in zip(
        samples, reconstructions,
        annotations['note_str'], annotations['pitch'], annotations['instrument_family_str']):
    for audio_format in ['wav']:
        filename_original = audio_output_dir / f'{note_str}-original.{audio_format}'
        sample_audio = make_audio(sample.unsqueeze(0))
        if NORMALIZE_AUDIO:
            sample_audio = peak_normalize(sample_audio)
        sf.write(str(filename_original), sample_audio,
                 samplerate=spectrograms_helper.fs_hz, format=audio_format)

        filename_reconstruction = audio_output_dir / f'{note_str}-reconstruction.{audio_format}'
        reconstruction_audio = make_audio(reconstruction.unsqueeze(0))
        if NORMALIZE_AUDIO:
            reconstruction_audio = peak_normalize(reconstruction_audio)
        sf.write(str(filename_reconstruction), reconstruction_audio,
                 samplerate=spectrograms_helper.fs_hz, format=audio_format)
    reconstructions_details[note_str] = {'pitch': int(pitch), 'instrument_family_str': instrument_family_str}

with open(samples_dir / 'reconstructions.json', 'w') as f:
    json.dump(reconstructions_details, f)

In [None]:
# Note identifiers of the exported batch.
annotations['note_str']

In [None]:
def write_spectrogram_image_to_file(logmel_and_IF, filename_base, hop_length, fs_hz,
                                    width=None, height=None):
    """Save both channels of a (log-mel, IF) tensor as separate PNG images.

    Writes `<filename_base>-logmel.png` (channel 0) and `<filename_base>-IF.png`
    (channel 1), with `filename_base` resolved to an absolute path.

    Arguments:
        logmel_and_IF (torch.Tensor): tensor indexed by channel first.
        filename_base (pathlib.Path or str): output path prefix, without suffix.
        hop_length (int): STFT hop size, passed to librosa.
        fs_hz: sample rate, passed to librosa.
        width, height (float, optional): figure size in inches (see `full_frame`).
    """
    # Depending on the librosa version, `import librosa` alone may not load
    # the `display` submodule.
    import librosa.display

    base_path = str(pathlib.Path(filename_base).resolve())
    for channel_index, channel_name in enumerate(['logmel', 'IF']):
        fig, ax = full_frame(width=width, height=height)
        librosa.display.specshow(logmel_and_IF[channel_index].detach().cpu().numpy(),
                                 sr=fs_hz, hop_length=hop_length,
                                 ax=ax)
        fig.savefig(base_path + f'-{channel_name}.png')
        # Close each figure once saved so that figures do not accumulate.
        plt.close(fig)

# Save log-mel and IF images of every exported original and reconstruction.
WIDTH = 6
HEIGHT = 4
images_samples_dir = samples_dir / 'images'
images_samples_dir.mkdir(parents=True, exist_ok=True)
for sample, reconstruction, note_str in zip(samples, reconstructions,
                                            annotations['note_str']):
    for audio, audio_name in zip([sample, reconstruction],
                                 ['original', 'reconstruction']):
        filename_image_base = images_samples_dir / f'{note_str}-{audio_name}'
        write_spectrogram_image_to_file(audio, filename_image_base,
                                        spectrograms_helper.hop_length,
                                        spectrograms_helper.fs_hz,
                                        width=WIDTH, height=HEIGHT)

In [None]:
%%bash
# Archive the exported reconstructions (./paper/reconstructions) as a gzip-compressed tarball.
tar -C paper -czf paper/reconstructions.tar.gz reconstructions

In [None]:
# Reconstruct one fixed validation example and export its audio, spectrogram
# images and code maps for the spectrogram-comparison figure.
sample_index = 11000
# Load the dataset item once and reuse it for the loader and for its metadata.
sample_item = nsynth_validation_dataset[sample_index]
single_sample_loader = dataloader_class(
    dataset=[sample_item],
    spectrograms_helper=spectrograms_helper,
    batch_size=1,
    num_workers=8, shuffle=True,
    pin_memory=True)

sample_note_str = sample_item[1]['note_str']
acoustic_guitar_sample, acoustic_guitar_reconstruction, acoustic_guitar_id_top, acoustic_guitar_id_bottom, annotations = sample_reconstructions(single_sample_loader)
spectrograms_comparison_path = pathlib.Path('./paper/spectrograms_comparison/')
spectrograms_comparison_path.mkdir(parents=True, exist_ok=True)

# Set the audio format here instead of relying on a loop variable from another cell.
audio_format = 'wav'
filename_reconstruction = spectrograms_comparison_path / 'audio' / f"{sample_note_str}-{vqvae_run_ID}_{checkpoint_epoch}-reconstruction.{audio_format}"
filename_reconstruction.parent.mkdir(parents=True, exist_ok=True)
acoustic_guitar_reconstruction_audio = make_audio(acoustic_guitar_reconstruction)
if NORMALIZE_AUDIO:
    # Divide by the peak absolute amplitude so that every sample lies in [-1, 1].
    peak_amplitude = np.abs(acoustic_guitar_reconstruction_audio).max()
    if peak_amplitude > 0:
        acoustic_guitar_reconstruction_audio = acoustic_guitar_reconstruction_audio / peak_amplitude
sf.write(str(filename_reconstruction), acoustic_guitar_reconstruction_audio,
         samplerate=spectrograms_helper.fs_hz, format=audio_format)

(spectrograms_comparison_path / 'images/').mkdir(exist_ok=True, parents=True)
base_filename = str((spectrograms_comparison_path / 'images' / f"{sample_note_str}-{vqvae_run_ID}_{checkpoint_epoch}").resolve())

write_spectrogram_image_to_file(
    acoustic_guitar_sample[0].detach(),
    pathlib.Path(base_filename + '-original'),
    spectrograms_helper.hop_length, spectrograms_helper.fs_hz)
write_spectrogram_image_to_file(
    acoustic_guitar_reconstruction[0].detach(),
    pathlib.Path(base_filename + '-reconstruction'),
    spectrograms_helper.hop_length, spectrograms_helper.fs_hz)

# Top and bottom code maps of the example.
heatmap(acoustic_guitar_id_top[0], cmap='magma')
plt.savefig(base_filename + '-top.png')
heatmap(acoustic_guitar_id_bottom[0], cmap='magma')
plt.savefig(base_filename + '-bottom.png')

In [None]:
%%bash
# Archive ./paper/spectrograms_comparison as a gzip-compressed tarball (`-z`,
# as the .tar.gz name implies).
tar -C paper -czvf spectrograms_comparison.tar.gz spectrograms_comparison

In [None]:
# Reconstruction without using all hidden layers: decode with the top (resp.
# bottom) quantized map replaced by zeros, and compare with the full decoding.

with torch.no_grad():
    samples, _ = next(iter(single_sample_loader))
    quant_t, quant_b, *_ = vqvae.encode(samples.to(DEVICE))
    decoded_only_bottom = vqvae.decode(torch.zeros_like(quant_t), quant_b).cpu()
    decoded_only_top = vqvae.decode(quant_t, torch.zeros_like(quant_b)).cpu()
    decoded = vqvae.decode(quant_t, quant_b).cpu()

specs_and_IFs = [samples[0], decoded_only_bottom[0], decoded_only_top[0], decoded[0]]
# Unpack: the plotting function takes one positional argument per spectrogram.
plot_specs_and_IFs(*specs_and_IFs)
# plt.suptitle('Original, Reconstructed using only bottom / only top / top and bottom', pad=20)

In [None]:
# Display the learned codebooks. For each layer: the mean-centred code vectors,
# and a histogram of the pairwise correlation distances between codes.
from scipy.spatial.distance import pdist

named_code_layers = {'top': vqvae.quantize_t, 'bottom': vqvae.quantize_b}
num_layers = len(named_code_layers)
fig, axes = plt.subplots(num_layers, 2, figsize=(20, 5*num_layers), squeeze=False)

for layer_index, (layer_name, code_layer) in enumerate(named_code_layers.items()):
    ax_codes, ax_hist = axes[layer_index]
    # Transposed so that each row is one code vector (assuming `embed` is
    # stored as (dim, n_embed) — TODO confirm).
    codes = code_layer.embed.data.cpu().T
    codes_matrix = codes - codes.mean(dim=0, keepdim=True)
    ax_codes.matshow(codes_matrix)
    ax_codes.set_title(f'{layer_name} layer codes (layer number {layer_index+1})', pad=20)
    # `pdist(..., 'correlation')` returns 1 - Pearson correlation for each pair
    # of rows. NOTE(review): adding a constant offset does not change those
    # correlations.
    pairwise_correlation_distances = pdist(codes_matrix.numpy() + 1e-7, 'correlation')
    ax_hist.hist(pairwise_correlation_distances, bins=30)
    ax_hist.set_title(f'{str(layer_name).capitalize()} layer codes pairwise correlation distance \n (layer number {layer_index+1})', pad=20)

plt.tight_layout()
plt.show()

In [None]:
# Code usage distribution of the VQ-VAE over a dataset: one counter per
# codebook entry, for the top and bottom codebooks.
num_embeddings_top = vqvae.quantize_t.n_embed
num_embeddings_bottom = vqvae.quantize_b.n_embed
code_usage_top = torch.zeros(num_embeddings_top, dtype=torch.long, device=DEVICE)
code_usage_bottom = torch.zeros(num_embeddings_bottom, dtype=torch.long, device=DEVICE)

def count_uses(codes_batch, code_usage_matrix):
    """Add, in place, the number of occurrences of each code index in `codes_batch`.

    Arguments:
        codes_batch (torch.Tensor): integer code indices of any shape.
        code_usage_matrix (torch.Tensor): 1D integer counters, one per code;
            updated in place.
    """
    code_usage_matrix += torch.bincount(codes_batch.reshape(-1),
                                        minlength=len(code_usage_matrix))

# Only validation data is loaded in this notebook, so code usage is counted on
# the validation set, using the 64-example batches of `large_loader`.
with torch.no_grad():
    for batch, *_ in tqdm_notebook(large_loader):
        _, _, _, id_t, id_b, _, _ = vqvae.encode(batch.to(DEVICE))
        count_uses(id_t, code_usage_top)
        count_uses(id_b, code_usage_bottom)

fig, (ax_bottom, ax_top) = plt.subplots(2, 1, figsize=(30, 15))
ax_bottom.set_title("Bottom layer")
ax_bottom.bar(range(num_embeddings_bottom), height=code_usage_bottom.cpu().numpy())
ax_top.set_title("Top layer")
ax_top.bar(range(num_embeddings_top), height=code_usage_top.cpu().numpy())
# Set the overall title before saving so that it is included in the file.
fig.suptitle("Code usage")
fig.savefig(f'code_usage-epoch_{checkpoint_epoch}.svg')
plt.show()

## Embeddings dimension comparison

In this section, we compare the dimensions of the training/target data with those of the embeddings,
in order to compute the compression ratio and check that it makes sense.

In [None]:
# Take the single example (with its batch dimension) of the one-example loader.
one_shot_single_sample_iterator = iter(single_sample_loader)
sample_as_batch, _ = next(one_shot_single_sample_iterator)
sample = sample_as_batch[0]

with torch.no_grad():
    # Keep the first five of the seven outputs of `encode`, dropping the batch
    # dimension of each.
    *results_as_batches, _, _ = vqvae.encode(sample_as_batch.to(DEVICE))
    quant_t, quant_b, diff_t_plus_b, id_t, id_b = [
        tensor[0].cpu() for tensor in results_as_batches]

print('Original shape: ', sample.shape)
print('Quant bottom shape: ', quant_b.shape)
print('Quant top shape: ', quant_t.shape)
print('Code bottom shape: ', id_b.shape)
print('Code top shape: ', id_t.shape)
# Number of code indices relative to the number of input values (this ignores
# the number of bits needed per code index).
print("Compression ratio: ", (id_b.numel() + id_t.numel()) / sample.numel())

# Sound creation

In this section, we attempt to create new sounds by combining codes for heterogeneous sounds. 

In [None]:
# Two random validation examples, loaded with the same loader class and
# spectrogram settings as the other loaders of this notebook.
two_sample_loader = dataloader_class(
    dataset=nsynth_validation_dataset,
    spectrograms_helper=spectrograms_helper,
    batch_size=2,
    num_workers=2, shuffle=True,
    pin_memory=False)
one_shot_two_sample_iterator = iter(two_sample_loader)
with torch.no_grad():
    samples, _ = next(one_shot_two_sample_iterator)
    quant_t, quant_b, _, id_t, id_b, _, _ = vqvae.encode(samples.to(DEVICE))

print("Original audios")
make_audio_player(samples)

print('Reconstructions')
with torch.no_grad():
    reconstructions = vqvae.decode_code(id_t, id_b)
make_audio_player(reconstructions)

# Each sample keeps its own top codes; flipping the batch dimension of the
# bottom codes exchanges them between the two samples.
print('Exchange bottom codes between the two samples')
with torch.no_grad():
    reconstructions_switched_codes = vqvae.decode_code(id_t, id_b.flip(0))
make_audio_player(reconstructions_switched_codes)

plot_specs_and_IFs(*samples[0:2], *reconstructions[0:2], *reconstructions_switched_codes[0:2])

# print('Exchange half of the codes (in the temporal dimension)')
# def switch_temporal_halves(codes_tensor: torch.Tensor) -> torch.Tensor:
#     batch_dim, frequency_dim, temporal_dim = 0, 1, 2
#     duration_n = codes_tensor.shape[temporal_dim]
#     first_half = codes_tensor[:, :, :duration_n//2]
#     second_half = codes_tensor[:, :, duration_n//2:]
#     return torch.cat([first_half.flip(0), second_half], temporal_dim)
# with torch.no_grad():
#     reconstructions_halves_switched = vqvae.decode_code(
#         *[switch_temporal_halves(codes) for codes in [id_t, id_b]])
# make_audio_player(reconstructions_halves_switched)

# print('Set half (temporally) of the latent map to zero')
# def zero_temporal_halves(codes_tensor: torch.Tensor) -> torch.Tensor:
#     batch_dim, frequency_dim, temporal_dim = 0, 1, 2
#     duration_n = codes_tensor.shape[temporal_dim]
#     first_half = codes_tensor[:, :, :duration_n//2]
#     zero_second_half = torch.zeros(codes_tensor[:, :, duration_n//2:].shape, dtype=first_half.dtype).to(device)
#     return torch.cat([first_half, zero_second_half], dim=temporal_dim)
# with torch.no_grad():
#     reconstructions_halves_zeroed = vqvae.decode_code(
#         *[zero_temporal_halves(codes) for codes in [id_t, id_b]])
# make_audio_player(reconstructions_halves_zeroed)

In [None]:
print('Linear interpolation of the embeddings')
def make_interpolations(start_map, end_map, steps):
    """Linearly interpolate between two latent maps of identical shape.

    Arguments:
        start_map, end_map (torch.Tensor): maps of identical shape (any rank).
        steps (int): number of interpolation points, both endpoints included.
    Returns:
        torch.Tensor: shape (steps, *start_map.shape); entry i equals
        start_map + i / (steps - 1) * (end_map - start_map).
    """
    # Create the ratios on the maps' own device and dtype, shaped to broadcast
    # over every dimension of the maps.
    ratios = torch.linspace(0, 1, steps, device=start_map.device,
                            dtype=start_map.dtype)
    ratios = ratios.view(-1, *([1] * start_map.dim()))
    return start_map + ratios * (end_map - start_map)

with torch.no_grad():
    num_steps = 5
    interpolations_unquant_t_and_b = [
        make_interpolations(codes[0], codes[1], num_steps)
        for codes in (quant_t, quant_b)]
    print(interpolations_unquant_t_and_b[0][:,0,0])
    print(interpolations_unquant_t_and_b[0].shape)
    print(interpolations_unquant_t_and_b[1].shape)

    # Re-quantize each interpolated map; the permutations move the channel
    # axis last for the quantizers and back afterwards (presumably the layout
    # the quantizers expect, cf. the warning about reshaping).
    interpolations_quant_t, interpolations_quant_b = [
        quantizer(interpolations_unquant.permute(0, 2, 3, 1))[0].permute(0, 3, 1, 2)
        for quantizer, interpolations_unquant in zip(
            [vqvae.quantize_t, vqvae.quantize_b],
            interpolations_unquant_t_and_b)
    ]

    print(interpolations_quant_t[:,0,0])
    interpolations = vqvae.decode(interpolations_quant_t, interpolations_quant_b)
for index, interpolation in enumerate(interpolations):
    print(f'Step {index}')
    # Keep a batch dimension of size 1, as for other single examples.
    make_audio_player(interpolation.unsqueeze(0))

# Compare with the mix (mean) of the two reconstructed signals in the temporal
# domain.
with torch.no_grad():
    reconstructions_audio = spectrograms_helper.to_audio(reconstructions)
IPython.display.display(IPython.display.Audio(
    reconstructions_audio.mean(dim=0).flatten().cpu().numpy(),
    rate=spectrograms_helper.fs_hz, normalize=True))
specs_and_IFs = [samples[0], samples[1], interpolations[1], interpolations[2], interpolations[3], samples.sum(0)]
plot_specs_and_IFs(*specs_and_IFs)

In [None]:
# WARNING: don't feed latent maps to quantizer by reshaping them, use .permute()
# Otherwise, the results are scrambled!
# Demonstration: re-quantize a plainly reshaped top latent map and display its
# difference with the input (`display` is provided by IPython in notebooks).
# NOTE(review): the target shape (128, 16, 64) is hard-coded and presumably
# matches this model's top map — confirm against `quant_t.shape`.
with torch.no_grad():
    embeddings_t = quant_t[0].reshape(128, 16, 64)
    requantized_embeddings_t = vqvae.quantize_t(embeddings_t)[0]
    display(requantized_embeddings_t - embeddings_t)

### Latent-codes distortion

Here, we slightly modify the latent code maps to check the effect on the decoded audio.

In [None]:
# Reconstruct one example and move its code maps to the CPU for editing.
sample, reconstruction, *_ = sample_reconstructions(single_sample_loader)

with torch.no_grad():
    *results_as_batches, _, _ = vqvae.encode(sample.to(DEVICE))
    quant_t, quant_b, diff_t_plus_b, id_t, id_b = [
        tensor.cpu() for tensor in results_as_batches]

all_code_layers = {'top': id_t, 'bottom': id_b}
code_layers_to_corrupt_names = ['bottom']
code_layers_to_keep_names = list(set(all_code_layers.keys()) - set(code_layers_to_corrupt_names))
num_embeddings = {'top': vqvae.quantize_t.n_embed,
                  'bottom': vqvae.quantize_b.n_embed}

# For each offset, shift every code index of the corrupted layers by that
# offset (modulo the codebook size); the results are stacked along dim 0.
add_values = [+1, -1, +10, -10, +20, -20, +100, -100, +200, -200]
id_plus_corrupted = {
    code_layer_name: torch.cat([all_code_layers[code_layer_name]
                                .add(add_value)
                                .remainder(num_embeddings[code_layer_name])
                                .long()
     for add_value in add_values
    ], 0)
    for code_layer_name in code_layers_to_corrupt_names
}

# Unchanged layers, repeated once per offset to match the corrupted batch.
id_plus_original = {
    code_layer_name: all_code_layers[code_layer_name].repeat(len(add_values), 1, 1)
    for code_layer_name in code_layers_to_keep_names
}

def pick_codes(original, corrupted):
    """Assemble code maps for `vqvae.decode_code`, in `all_code_layers` order.

    For each layer name, the map from `original` is used when present and not
    None, otherwise the one from `corrupted`. Every map is moved to DEVICE.

    Arguments:
        original (dict): layer name -> code map to keep.
        corrupted (dict): layer name -> code map used for the other layers.
    Returns:
        list of torch.Tensor: one code map per layer.
    """
    codes = []
    for code_layer_name in all_code_layers.keys():
        code_map = original.get(code_layer_name)
        if code_map is None:
            code_map = corrupted[code_layer_name]
        codes.append(code_map.to(DEVICE))
    return codes

# weights = torch.Tensor([0.2, 0.6, 0.2])
# id_stochastic_1 = [
#     id_map.add(torch.multinomial(weights, id_map.numel(), replacement=True).reshape(id_map.shape).add(-1).to(device)).remainder(n_embed).long()
#     for id_map, n_embed in zip([id_t, id_b],
#                                [vqvae.quantize_t.n_embed, vqvae.quantize_b.n_embed])
# ]

# id_stochastic_randn_0_20_corrupted = [
#     id_map.add((torch.randn_like(id_map, dtype=float).to(device) * 20).long()).remainder(n_embed).long()
#     for id_map, n_embed in zip(codes_to_corrupt,
#                                num_embeddings_to_corrupt)
# ]

# print(id_stochastic_randn_0_20_corrupted)
# print(codes_to_keep)
# id_stochastic = id_stochastic_randn_0_20_corrupted + codes_to_keep
# print(id_stochastic)

with torch.no_grad():
    # Kept layers come from `id_plus_original`, the others from `id_plus_corrupted`.
    reconstructions_plus = vqvae.decode_code(*pick_codes(id_plus_original, id_plus_corrupted))

print('Original')
make_audio_player(sample)

print('Reconstructions')
make_audio_player(reconstruction)

for index, add_value in enumerate(add_values):
    print('Uniform add', add_value)
    # Slice (rather than index) to keep a batch dimension of size 1.
    make_audio_player(reconstructions_plus[index:index + 1])

# make_audio_player(reconstructions_stochastic_1)

# make_audio_player(reconstructions_stochastic_randn_0_20)

In [None]:
# Decode code maps drawn at random from {0, 1, 2} with probabilities
# (0.2, 0.6, 0.2), i.e. using only the first three codebook entries.
# NOTE(review): the commented-out variant in the previous cell instead adds
# such draws (shifted by -1) to the actual codes — confirm which is intended.
weights = torch.Tensor([0.2, 0.6, 0.2])
id_stochastic = [
    torch.multinomial(weights, id_map.numel(), replacement=True).reshape(id_map.shape).to(DEVICE).long()
    for id_map in (id_t, id_b)
]

with torch.no_grad():
    reconstructions_stochastic = vqvae.decode_code(*id_stochastic)
make_audio_player(reconstructions_stochastic)

# Visualizing codemaps

In [None]:
# Paper figures: spectrograms and code maps of one example.
# NOTE(review): this cell uses `sample`, `reconstructions`, `id_top` and
# `id_bottom` from earlier cells, and `transformer_top`/`transformer_bottom`,
# which are only loaded in the "Sound generation" section below — run that
# cell first.
import librosa.display

make_audio_player(sample[:1])
plt.clf()
plt.close()

# Mel magnitude (channel 0) and mel IF (channel 1), each drawn without a frame.
for channel_index, output_filename in ((0, 'mel_spec_amplitude.pdf'), (1, 'mel_IF.pdf')):
    fig = plt.figure(frameon=False)
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis('off')
    librosa.display.specshow(sample[0, channel_index].cpu().numpy(), sr=spectrograms_helper.fs_hz,
                             hop_length=spectrograms_helper.hop_length, ax=ax)
    fig.savefig(output_filename)

make_spectrogram_image(sample[0, 0], filename='upsampled_spectrogram',
                       upsampling_factor=4)
make_spectrogram_image(reconstructions[0, 0], filename='upsampled_reconstructed_spectrogram',
                       upsampling_factor=4)
make_spectrogram_image(sample[0, 0].flip(0), filename='flipped-upsampled_spectrogram',
                       upsampling_factor=4)
make_spectrogram_image(reconstructions[0, 0].flip(0), filename='flipped-upsampled_reconstructed_spectrogram',
                       upsampling_factor=4)

# Full code maps (`heatmap` creates its own figure).
heatmap(id_top[0], "Top", cmap='magma')
heatmap(id_bottom[0], "Bottom", cmap='magma')

# Code maps flattened into the sequences handled by the transformers.
flattened_top = transformer_top.flatten_map(id_top[:1], kind='source')
heatmap(flattened_top, cmap='magma', height_in=1.5)
flattened_bottom = transformer_bottom.flatten_map(id_bottom[:1], kind='target')
heatmap(flattened_bottom, cmap='magma', height_in=1.5)

# Small extracts of the code maps, as described by the output filenames.
heatmap(id_top[0, :8, :1], cmap='magma')
plt.savefig('top_extract_1s_first_quarter_of_frequency_axis.pdf')
heatmap(id_bottom[0, :16, :2], cmap='magma')
plt.savefig('bottom_extract_1s_first_quarter_of_frequency_axis.pdf')

print("Flattened top, sample")
heatmap(flattened_top[:, :8], cmap='magma', height_in=1.5)
plt.savefig('top_flattened_extract_1s_first_quarter_of_frequency_axis.pdf')
print("Flattened bottom, sample")
heatmap(flattened_bottom[:, :32], cmap='magma', height_in=1.5)
plt.savefig('bottom_flattened_extract_1s_first_quarter_of_frequency_axis.pdf')


In [None]:
# Edit short fragments of the flattened code sequences by remapping a few code
# indices, keeping them within each codebook's range.
num_codes_top = vqvae.quantize_t.n_embed
num_codes_bottom = vqvae.quantize_b.n_embed

flattened_top_edited_fragment = flattened_top[:,:8].clone()
flattened_top_edited_fragment[:,-3] = (flattened_top_edited_fragment[:,-3] + 666) % num_codes_top
heatmap(flattened_top[:,:8], cmap='magma', height_in=1.5)
heatmap(flattened_top_edited_fragment, cmap='magma', height_in=1.5)
plt.savefig('edited-top_flattened_extract_1s_first_quarter_of_frequency_axis.pdf')
print("Flattened bottom, sample")
flattened_bottom_edited_fragment = flattened_bottom[:,:32].clone()
flattened_bottom_edited_fragment[:,-4*3:-4*2] = ((4 + flattened_bottom_edited_fragment[:,-4*3:-4*2]) ** 3) % num_codes_bottom
heatmap(flattened_bottom[:,:32], cmap='magma', height_in=1.5)
heatmap(flattened_bottom_edited_fragment, cmap='magma', height_in=1.5)
plt.savefig('edited-bottom_flattened_extract_1s_first_quarter_of_frequency_axis.pdf')

In [None]:
# Write the edited fragments back into the full flattened sequences.
flattened_top_edited = flattened_top.clone()
flattened_top_edited[:,:8] = flattened_top_edited_fragment
flattened_bottom_edited = flattened_bottom.clone()
flattened_bottom_edited[:,:32] = flattened_bottom_edited_fragment

# Invert the flattening with the same `kind` it was done with ('source' for
# the top codes, 'target' for the bottom codes), so that unedited positions
# round-trip exactly.
id_top_edited = transformer_top.to_time_frequency_map(flattened_top_edited, kind='source')
id_bottom_edited = transformer_bottom.to_time_frequency_map(flattened_bottom_edited, kind='target')

heatmap(id_top_edited[0,:8,:1], cmap='magma')
plt.savefig('edited-top_extract_1s_first_quarter_of_frequency_axis.pdf')
heatmap(id_bottom_edited[0,:16,:2], cmap='magma')
plt.savefig('edited-bottom_extract_1s_first_quarter_of_frequency_axis.pdf')

with torch.no_grad():
    reconstructed = vqvae.decode_code(id_top_edited, id_bottom_edited)

make_spectrogram_image(reconstructed[0,0], filename='upsampled_edited_spectrogram',
                       upsampling_factor=4)
make_spectrogram_image(reconstructed[0,0].flip(0), filename='flipped-upsampled_edited_spectrogram',
                       upsampling_factor=4)

In [None]:
# Run `time_stretch_and_resample` from the `sample` module.
# NOTE(review): it is called with no arguments, so nothing defined in this notebook
# (models, codes, paths) is passed in — confirm its defaults do what is intended here.
from sample import time_stretch_and_resample

time_stretch_and_resample()

# Sound generation

In [None]:
# Locate the top/bottom prior checkpoints stored under the loaded VQ-VAE run's
# code-prediction folder.
transformers_root_path = pathlib.Path(f'./checkpoints/code_prediction/vqvae-{vqvae_run_ID}/')

transformer_top_run_id = 'Transformer-top_layer-20200513-231538-0bd9f5'
transformer_bottom_run_id = 'Transformer-bottom_layer-20200512-165540-9964e5'

transformer_top_folder = transformers_root_path / transformer_top_run_id
transformer_bottom_folder = transformers_root_path / transformer_bottom_run_id

transformer_top_weights_path = transformer_top_folder / 'Transformer-layer_top.pt'
transformer_bottom_weights_path = transformer_bottom_folder / 'Transformer-layer_bottom.pt'

transformer_top_model_parameters_path = transformer_top_folder / 'model_instantiation_parameters.json'
transformer_bottom_model_parameters_path = transformer_bottom_folder / 'model_instantiation_parameters.json'

# Instantiate both priors on DEVICE and switch them to evaluation mode
# (top first, then bottom).
transformer_top, transformer_bottom = (
    VQNSynthTransformer.from_parameters_and_weights(
        parameters_path, weights_path, device=DEVICE
    ).eval()
    for parameters_path, weights_path in (
        (transformer_top_model_parameters_path, transformer_top_weights_path),
        (transformer_bottom_model_parameters_path, transformer_bottom_weights_path),
    )
)

In [None]:
import numpy as np
from sample import sample_model, make_conditioning_tensors
from dataset import LMDBDataset
import soundfile as sf

# Load the code dataset; below it is only used for its label encoders, which map
# pitch / instrument-family labels to the integer classes the priors are conditioned on.
classes_for_conditioning = ['pitch', 'instrument_family_str']
# NOTE(review): this path holds codes from VQ-VAE run 20200309-220303-d006ab (epoch 436),
# not from `vqvae_run_ID` — harmless if only the label encoders are needed, but confirm.
DATABASE_PATH = pathlib.Path(
    './codes/vqvae-20200309-220303-d006ab-weights-vqvae_nsynth_436/train/'
).resolve()
dataset = LMDBDataset(DATABASE_PATH,
                      classes_for_conditioning=list(classes_for_conditioning))
label_encoders_per_conditioning = dataset.label_encoders

instrument_family_encoder = label_encoders_per_conditioning['instrument_family_str']
all_instrument_labels = instrument_family_encoder.classes_.tolist()

# Flag read by the sampling cells below to decide whether to normalise generated samples.
NORMALIZE_AUDIO = True

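## Unconditional sampling

Generate new samples with the two priors, without any reference audio. The priors are conditioned only on pitch and instrument-family labels.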

In [None]:
# Sample one clip per (instrument family, pitch) pair from the two priors. No
# reference audio is used: only pitch and instrument-family labels condition them.
unconditional_sampling_pitches = np.arange(24, 85, 7)
unconditional_sampling_instruments = all_instrument_labels
num_pitches = len(unconditional_sampling_pitches)
num_instruments = len(unconditional_sampling_instruments)
num_samples = num_pitches * num_instruments
print(f"Will generate a total of {num_samples} samples, "
      f"from {num_instruments} different instruments")

# DEVICE comes from the configuration cell at the top; it is deliberately not
# redefined here so that changing it there takes effect everywhere.
class_conditioning_tensors = {}

# Pitch conditioning: one label per batch entry, shared by every instrument.
encoded_pitches = label_encoders_per_conditioning['pitch'].transform(
    unconditional_sampling_pitches).transpose()
class_conditioning_tensors['pitch'] = (
    torch.from_numpy(encoded_pitches).long()
    .to(DEVICE))

unconditional_generation_path = pathlib.Path(f'./paper/unconditional_generation/{vqvae_run_ID}_{checkpoint_epoch}')
unconditional_generation_path.mkdir(parents=True, exist_ok=True)

# Iterate over every instrument family so the output matches the count printed above.
for instrument in unconditional_sampling_instruments:
    encoded_instrument = label_encoders_per_conditioning['instrument_family_str'].transform(
        [instrument]).transpose()
    # The same instrument label for every pitch in the batch.
    class_conditioning_tensors['instrument_family_str'] = (
        torch.from_numpy(encoded_instrument).long().repeat(num_pitches)
        .to(DEVICE))

    with torch.no_grad():
        # Sample the top code map first, then the bottom one conditioned on it.
        transformer_top.to(DEVICE)
        transformer_top.eval()
        sampled_top = sample_model(
            transformer_top, device=DEVICE, batch_size=num_pitches,
            class_conditioning=class_conditioning_tensors,
            codemap_size=transformer_top.shape,
            temperature=1.0, top_p_sampling_p=0.8,
            use_multi_gpus=False)

        transformer_bottom.to(DEVICE)
        transformer_bottom.eval()
        sampled_bottom = sample_model(
            transformer_bottom, device=DEVICE, batch_size=num_pitches,
            condition=sampled_top,
            class_conditioning=class_conditioning_tensors,
            codemap_size=transformer_bottom.shape,
            temperature=1.0, top_p_sampling_p=0.8,
            use_multi_gpus=False)

        vqvae.to(DEVICE)
        vqvae.eval()
        samples_batch = vqvae.decode_code(sampled_top, sampled_bottom)
        for pitch, sample in zip(unconditional_sampling_pitches, samples_batch):
            filename = unconditional_generation_path / f'{instrument}-{pitch}.wav'
            audio = spectrograms_helper.to_audio(sample.unsqueeze(0)).cpu().numpy()
            if NORMALIZE_AUDIO:
                # Peak-normalise the waveform that is written to disk. Scaling the
                # spectrogram `sample` at this point would not affect `audio`.
                # Silent clips are skipped to avoid dividing by zero.
                peak = np.abs(audio).max()
                if peak > 0:
                    audio = audio / peak
            sf.write(str(filename), audio[0],
                     samplerate=spectrograms_helper.fs_hz, format='WAV')

In [None]:
# Shape of `sample`, the loop variable left over from whichever sampling loop ran
# last (so this only works after such a loop has been executed).
sample.shape

In [None]:
# Inspect the bottom prior's `local_class_conditioning` attribute.
transformer_bottom.local_class_conditioning

In [None]:
# NOTE(review): `audio_batch` is not defined in this part of the notebook. If the intent
# is to listen to the clips generated above, `samples_batch` (a `decode_code` output,
# like `decoded_sample` passed to `make_audio_player` further down) may be meant — confirm.
make_audio_player(audio_batch)

In [None]:
# Re-export the most recent decoded batch to WAV files (one per pitch), using the
# `instrument` and `samples_batch` left over from the sampling loop above.
for pitch, sample in zip(unconditional_sampling_pitches, samples_batch):
    filename = unconditional_generation_path / f'{instrument}-{pitch}.wav'
    audio = spectrograms_helper.to_audio(sample.unsqueeze(0)).cpu().numpy()
    if NORMALIZE_AUDIO:
        # Normalise the waveform itself. Dividing `sample` here would not change the
        # file and would mutate `samples_batch`. Silent clips are skipped to avoid /0.
        peak = np.abs(audio).max()
        if peak > 0:
            audio = audio / peak
    sf.write(str(filename), audio[0],
             samplerate=spectrograms_helper.fs_hz, format='WAV')

In [None]:
# Archive the clips generated above for the current VQ-VAE run.
# Python is used instead of a `%%bash` cell for two reasons:
# - the archived folder follows `unconditional_generation_path` (the shell command
#   hard-coded an older run, `20200309-220303-d006ab_436`);
# - mode 'w:gz' really gzips the archive, whereas `tar -cf` wrote an uncompressed
#   tar despite the `.tar.gz` name.
import tarfile

with tarfile.open('./paper/unconditional_generation.tar.gz', 'w:gz') as archive:
    # Same member layout as `tar -C ./paper`: member paths start at `unconditional_generation/`.
    archive.add(unconditional_generation_path,
                arcname=str(pathlib.Path('unconditional_generation')
                            / unconditional_generation_path.name))

In [None]:
# Comma-separated list of the pitches used for unconditional sampling.
', '.join(map(str, unconditional_sampling_pitches))

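## Masked inpainting

Resample masked regions of a reference sample's code maps under new pitch and instrument-family conditioning.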

In [None]:
# id_bottom_sample_sequence = torch.LongTensor([127,504,30,501,419,175,292,283,43,43,153,159,189,203,266,510,427,427,490,488,183,509,263,401,442,473,393,358,219,15,435,325,103,210,102,478,326,485,319,52,167,121,250,10,327,228,155,167,373,195,30,401,80,220,342,211,283,494,427,478,492,252,429,257,366,148,494,429,511,361,472,83,61,494,467,157,385,420,427,78,78,69,326,326,385,327,121,220,429,157,69,326,275,126,453,385,157,492,77,54,385,157,77,121,120,327,78,398,126,134,134,492,80,396,134,492,134,396,385,134,327,385,472,61,80,460,54,246,446,74,490,306,12,444,153,86,196,439,283,490,359,478,30,502,501,509,483,366,294,60,387,270,330,325,393,260,364,266,196,501,319,127,359,450,127,215,510,9,168,107,208,266,102,427,431,378,129,36,203,409,131,54,309,54,154,110,221,395,279,89,49,220,1,210,503,259,175,494,49,157,145,77,103,155,427,77,221,472,167,413,29,138,83,78,59,157,412,121,494,429,494,220,61,355,157,413,251,350,121,412,398,54,49,220,78,121,472,54,398,78,54,54,121,78,78,49,157,49,157,398,78,281,326,188,367,396,53,12,415,30,159,444,437,164,153,102,294,325,43,219,358,110,153,415,308,173,450,260,166,338,330,298,46,258,319,127,364,478,292,455,24,153,349,313,108,264,238,298,334,319,323,445,237,86,161,108,83,86,389,49,121,435,454,355,61,326,123,412,309,123,221,350,155,453,253,494,83,83,361,398,175,78,175,78,412,198,49,385,77,54,121,157,398,54,78,413,385,413,69,54,361,78,54,454,412,253,69,472,157,121,138,54,121,121,121,54,78,121,429,54,78,78,413,78,361,77,54,115,412,367,412,80,237,249,509,108,162,472,432,359,260,183,490,325,59,175,450,502,103,43,159,478,106,497,125,56,421,156,330,52,248,258,292,159,364,355,485,325,33,131,485,67,309,457,448,211,59,131,33,146,13,78,78,183,326,111,193,189,148,385,342,309,492,195,58,89,385,429,54,119,89,453,472,253,89,80,350,494,80,120,54,121,429,396,420,362,89,318,295,220,420,126,327,420,385,472,385,220,80,89,138,492,492,351,134,140,318,89,420,472,388,276,416,492,126,117,276,472,134,134,472,385,492,89,121,59,429,385,423,157,140,126])
# id_top_sample_sequence = torch.LongTensor([198,339,141,423,408,198,1,415,140,173,307,430,198,341,430,198,29,106,323,251,25,364,323,479,286,295,395,385,389,389,358,78,389,468,340,482,482,425,35,454,140,262,358,445,382,150,134,489,511,75,134,389,25,451,207,325,489,134,437,221,451,389,389,295,348,256,340,281,50,425,165,140,22,134,265,429,483,291,74,90,33,75,72,66,429,197,1,469,134,505,323,207,129,451,325,139,90,266,319,92,468,92,104,266,197,459,90,90,468,224,92,339,323,90,252,207,21,257,198,92,134,124,339,207,129,437,451,134])
id_bottom_sample_sequence = torch.LongTensor([327,361,361,184,385,253,423,431,412,494,478,358,437,435,213,118,189,454,319,319,433,148,494,229,94,477,506,179,386,399,427,134,479,413,42,463,465,193,61,4,301,140,208,419,89,80,427,211,429,350,78,281,200,351,321,115,344,126,295,429,50,413,413,319,321,158,54,331,14,460,298,472,190,328,228,138,433,40,65,257,321,304,155,246,178,178,278,7,252,178,318,374,269,269,481,481,178,269,481,481,178,178,481,481,178,269,21,481,178,178,481,481,178,269,481,481,269,178,481,481,62,269,353,174,178,477,353,50,298,79,407,446,478,319,229,281,365,441,55,348,101,417,75,208,510,30,30,435,78,10,175,71,441,49,348,431,435,386,115,350,108,355,509,427,80,341,123,306,267,446,164,15,77,350,237,251,407,511,211,366,484,447,162,8,420,39,362,83,221,501,413,366,412,407,220,91,435,492,455,220,3,385,3,416,251,27,145,257,13,386,106,396,418,353,7,418,120,7,138,353,7,7,353,353,353,328,353,353,7,318,7,318,460,7,128,353,353,353,353,353,353,353,353,353,353,7,353,353,481,126,460,188,340,126,278,134,279,154,15,427,230,358,154,263,497,30,364,60,213,334,272,44,319,298,478,43,472,106,103,107,375,431,347,287,430,327,29,295,108,455,371,111,253,204,120,341,334,511,506,367,350,103,15,407,175,161,154,367,290,103,305,435,148,61,220,412,251,367,220,30,195,210,29,35,167,472,29,472,508,429,3,385,49,140,251,290,419,50,337,433,7,418,418,418,429,353,429,353,353,353,353,353,353,353,353,353,353,388,7,140,328,481,433,7,353,353,353,353,353,353,353,353,353,353,353,353,460,188,460,386,278,140,418,492,198,358,309,447,490,53,183,405,450,432,447,179,101,112,289,179,444,53,475,229,161,319,242,193,260,30,63,447,145,494,29,492,441,367,237,191,249,74,115,503,313,65,179,447,281,198,419,155,43,237,106,83,181,103,220,295,138,220,138,472,77,501,49,83,389,487,398,398,110,138,78,134,128,420,396,492,251,162,39,246,189,477,472,327,418,252,318,50,50,418,126,126,328,328,126,126,328,418,126,134,252,481,134,386,252,418,126,120,418,418,134,134,418,418,134,134,418,418,134,134,418,481,134,134,81,117,120,157])
id_top_sample_sequence = torch.LongTensor([74,489,273,301,411,382,319,394,267,307,348,41,507,54,351,92,122,347,453,129,348,414,453,221,391,325,198,221,221,221,483,97,207,90,340,262,511,270,408,219,394,511,282,482,245,133,453,90,347,142,90,23,25,92,339,221,221,463,334,414,221,221,449,389,174,334,163,489,43,469,408,99,22,213,469,169,266,133,21,134,295,25,90,1,35,451,339,295,389,301,45,295,295,221,21,21,449,165,378,199,165,193,333,259,482,86,501,339,482,137,476,124,198,445,449,137,391,505,437,389,221,92,45,295,391,221,97,21])
# Reference sample whose flattened code sequences are listed above (presumably obtained
# by encoding it with the VQ-VAE — confirm); it is the starting point for inpainting.
sample_name = "organ_electronic_028-048-050.wav"  # from the NSynth test split
sample_original_pitch = 48

# Add a leading batch dimension, then fold each flat sequence into its code map.
id_top_sample_batch = id_top_sample_sequence.unsqueeze(0)
id_bottom_sample_batch = id_bottom_sample_sequence.unsqueeze(0)
id_top_sample_codemap = transformer_top.to_time_frequency_map(
    id_top_sample_batch, kind='target')
id_bottom_sample_codemap = transformer_bottom.to_time_frequency_map(
    id_bottom_sample_batch, kind='target')

def upsample_mask(mask_top, top_shape=None, bottom_shape=None):
    """Upsample a top-level code-map mask to the bottom level's resolution.

    Each entry of ``mask_top`` is repeated along tensor dims 1 and 2 by the integer
    ratio between the matching entries of ``bottom_shape`` and ``top_shape``, so a
    ``(batch, h, w)`` mask becomes ``(batch, h * rh, w * rw)``.

    Args:
        mask_top: mask tensor whose dims 1 and 2 correspond to ``top_shape[0]``
            and ``top_shape[1]`` (dim 0 is the batch dimension).
        top_shape: shape of the top code map; only entries 0 and 1 are used.
            Defaults to ``transformer_top.shape``.
        bottom_shape: shape of the bottom code map; only entries 0 and 1 are used.
            Defaults to ``transformer_bottom.shape``.

    Returns:
        A new tensor; ``mask_top`` is left unmodified.

    Raises:
        ValueError: if a bottom dimension is not a multiple of the matching top
            dimension (floor division would otherwise silently misalign the masks).
    """
    # Resolve defaults lazily so the function can be used without the global models.
    if top_shape is None:
        top_shape = transformer_top.shape
    if bottom_shape is None:
        bottom_shape = transformer_bottom.shape

    mask_bottom = mask_top.clone()
    # Code-map axis `axis` lives at tensor dim `axis + 1`, after the batch dim.
    for axis in (0, 1):
        top_size, bottom_size = top_shape[axis], bottom_shape[axis]
        if bottom_size % top_size != 0:
            raise ValueError(
                f"Bottom code-map size {bottom_size} is not a multiple of the "
                f"top code-map size {top_size} along axis {axis}")
        mask_bottom = torch.repeat_interleave(
            mask_bottom, repeats=bottom_size // top_size, dim=axis + 1)
    return mask_bottom

def _all_false_top_mask():
    """Return a boolean mask shaped like the top sample code map, with every entry False."""
    return torch.full_like(id_top_sample_codemap, fill_value=False, dtype=bool)


# Size of the top code map along tensor dim 1 (transformer_top.shape[0]); the
# "frequencies" masks below cover the lower indices of that dim.
top_codemap_rows = transformer_top.shape[0]

# Masks over the last tensor dim: its first two indices ("two seconds", per the name)...
two_seconds_all_frequencies_mask_top = _all_false_top_mask()
two_seconds_all_frequencies_mask_top[:, :, :2] = True

# ...and every index from 1 onwards ("last three seconds").
last_three_seconds_all_frequencies_mask_top = _all_false_top_mask()
last_three_seconds_all_frequencies_mask_top[:, :, 1:] = True

# Masks over the first half / first quarter of dim 1, for the whole last dim.
all_duration_bottom_half_frequencies_mask_top = _all_false_top_mask()
all_duration_bottom_half_frequencies_mask_top[:, :top_codemap_rows // 2, :] = True

all_duration_bottom_quarter_frequencies_mask_top = _all_false_top_mask()
all_duration_bottom_quarter_frequencies_mask_top[:, :top_codemap_rows // 4, :] = True

# Catalogue of every inpainting mask, each pairing a top-level mask with its
# bottom-level counterpart. Judging by the names (e.g. 'keep_top_resample_full_bottom'
# is an all-False top with an all-True bottom), True marks codes to resample — TODO
# confirm against `sample_model`.
all_named_masks = {
    'two_seconds_all_frequencies': {
        'top': two_seconds_all_frequencies_mask_top,
        'bottom': upsample_mask(two_seconds_all_frequencies_mask_top)
    },
    'last_three_seconds_all_frequencies': {
        'top': last_three_seconds_all_frequencies_mask_top,
        'bottom': upsample_mask(last_three_seconds_all_frequencies_mask_top)
    },
    'all_duration_bottom_half_frequencies': {
        'top': all_duration_bottom_half_frequencies_mask_top,
        'bottom': upsample_mask(all_duration_bottom_half_frequencies_mask_top)
    },
    'all_duration_bottom_quarter_frequencies': {
        'top': all_duration_bottom_quarter_frequencies_mask_top,
        'bottom': upsample_mask(all_duration_bottom_quarter_frequencies_mask_top)
    },
    'keep_top_resample_full_bottom': {
        'top': torch.full_like(id_top_sample_codemap, fill_value=False, dtype=bool),
        'bottom': torch.full_like(id_bottom_sample_codemap, fill_value=True, dtype=bool),
    },
    'resample_full_top_keep_bottom': {
        'top': torch.full_like(id_top_sample_codemap, fill_value=True, dtype=bool),
        'bottom': torch.full_like(id_bottom_sample_codemap, fill_value=False, dtype=bool),
    },
}

# Masks run by the inpainting sweep below; add any key of `all_named_masks` to enable it.
INPAINTING_MASK_NAMES = ['all_duration_bottom_quarter_frequencies']
named_masks = {name: all_named_masks[name] for name in INPAINTING_MASK_NAMES}

# Inpainting sweep: for every enabled mask and every instrument family, resample the
# masked codes of the reference sample under new pitch / instrument-family labels,
# drawing `num_samples_per_class` variations per pitch.
inpainting_pitches = np.array([24, 36, 43, 48, 60, 64, 67, 72])
inpainting_instruments = all_instrument_labels
num_pitches_inpainting = len(inpainting_pitches)

# DEVICE is taken from the configuration cell at the top rather than redefined here.
TEMPERATURE = 1
TOP_P = 0.8
num_samples_per_class = 4
class_conditioning_tensors_inpainting = {}

# Pitch labels are laid out pitch-major (each pitch repeated `num_samples_per_class`
# times in a row), which is what `samples_batch.split(num_samples_per_class)` relies on.
encoded_inpainting_pitches = label_encoders_per_conditioning['pitch'].transform(
    inpainting_pitches).transpose()
class_conditioning_tensors_inpainting['pitch'] = (
    torch.from_numpy(encoded_inpainting_pitches).long()
    .repeat_interleave(num_samples_per_class)
    .to(DEVICE))

masked_inpainting_path = pathlib.Path(f'./paper/masked_inpainting/{vqvae_run_ID}_{checkpoint_epoch}/temp_{TEMPERATURE}-top_p_{TOP_P}')
masked_inpainting_path.mkdir(parents=True, exist_ok=True)

for mask_name, masks in named_masks.items():
    save_dir_path = masked_inpainting_path / mask_name
    save_dir_path.mkdir(parents=True, exist_ok=True)

    print(f"Mask: {mask_name}")
    for instrument in inpainting_instruments:
        print(f"Instrument: {instrument}")
        encoded_instrument = label_encoders_per_conditioning['instrument_family_str'].transform(
            [instrument]).transpose()
        # One instrument label repeated over the whole batch (pitches x samples per pitch).
        class_conditioning_tensors_inpainting['instrument_family_str'] = (
            torch.from_numpy(encoded_instrument).long()
            .repeat(num_pitches_inpainting)
            .repeat_interleave(num_samples_per_class)
            .to(DEVICE))

        with torch.no_grad():
            # Every batch entry starts from the reference code maps; only masked codes change.
            transformer_top.to(DEVICE)
            transformer_top.eval()
            inpainted_top = sample_model(
                transformer_top, device=DEVICE,
                batch_size=num_pitches_inpainting*num_samples_per_class,
                class_conditioning=class_conditioning_tensors_inpainting,
                initial_code=id_top_sample_codemap.to(DEVICE).repeat((num_pitches_inpainting*num_samples_per_class, 1, 1)),
                codemap_size=transformer_top.shape,
                temperature=TEMPERATURE,
                top_p_sampling_p=TOP_P,
                use_multi_gpus=True,
                mask=masks['top'].to(DEVICE),
                progressbar_decorator=tqdm_notebook
            )

            transformer_bottom.to(DEVICE)
            transformer_bottom.eval()
            inpainted_bottom = sample_model(
                transformer_bottom,
                device=DEVICE,
                batch_size=num_pitches_inpainting*num_samples_per_class,
                condition=inpainted_top,
                initial_code=id_bottom_sample_codemap.to(DEVICE).repeat((num_pitches_inpainting*num_samples_per_class, 1, 1)),
                class_conditioning=class_conditioning_tensors_inpainting,
                codemap_size=transformer_bottom.shape,
                temperature=TEMPERATURE,
                top_p_sampling_p=TOP_P,
                use_multi_gpus=True,
                mask=masks['bottom'].to(DEVICE),
                progressbar_decorator=tqdm_notebook
            )

            vqvae.to(DEVICE)
            vqvae.eval()
            samples_batch = vqvae.decode_code(inpainted_top, inpainted_bottom)
            # One chunk of `num_samples_per_class` decoded samples per pitch.
            for pitch, sample in zip(inpainting_pitches, samples_batch.split(num_samples_per_class)):
                filename = save_dir_path / f'{instrument}-{pitch}.wav'
                # NOTE(review): `sample` holds several decoded samples; presumably
                # `make_audio` joins them into a single clip — confirm.
                audio = np.asarray(make_audio(sample))
                if NORMALIZE_AUDIO:
                    # Normalise the waveform itself: scaling the spectrogram `sample`
                    # after `make_audio` would not affect the file. Skip silence (/0).
                    peak = np.abs(audio).max()
                    if peak > 0:
                        audio = audio / peak
                sf.write(str(filename), audio,
                         samplerate=spectrograms_helper.fs_hz, format='WAV')

In [None]:
# Decode the reference sample's code maps unchanged (sanity check), print the
# decoded tensor's shape and pass it to the audio player.
with torch.no_grad():
    reference_top_codes = id_top_sample_codemap.to(DEVICE)
    reference_bottom_codes = id_bottom_sample_codemap.to(DEVICE)
    decoded_sample = vqvae.decode_code(reference_top_codes, reference_bottom_codes)
    print(decoded_sample.shape)
    make_audio_player(decoded_sample)

In [None]:
# Archive every mask folder produced so far for the current run and sampling settings.
# Python is used instead of a `%%bash` cell for three reasons:
# - the source folder follows `masked_inpainting_path` (the shell command hard-coded
#   an older run, `20200309-220303-d006ab_436`);
# - it no longer lists folders of masks that may be disabled;
# - mode 'w:gz' really gzips the `.tar.gz` archive, which `tar -cf` did not.
import tarfile

with tarfile.open('./paper/masked_inpainting.tar.gz', 'w:gz') as archive:
    for mask_folder in sorted(p for p in masked_inpainting_path.iterdir() if p.is_dir()):
        # Same member layout as `tar -C <temperature folder>`: paths start at the mask folder.
        archive.add(mask_folder, arcname=mask_folder.name)

In [None]:
# Run the VQ-VAE forward pass on an all-zeros input and print the shapes of the
# top/bottom codes, taken from its last two outputs.
silent_spectrogram = spectrograms_helper.to_spectrogram(torch.zeros(1, 100000).to(DEVICE))
vqvae_outputs = vqvae.forward(silent_spectrogram)
top_code, bottom_code = vqvae_outputs[-2], vqvae_outputs[-1]
for code_map in (top_code, bottom_code):
    print(code_map.shape)

In [None]:
# Inspect the VQ-VAE's `resolution_factors` attribute.
vqvae.resolution_factors