# Level 2 VQ-VAE Run Script

## Import libraries

In [None]:
import IPython

import torch
from torch.utils.data import DataLoader
import yaml
import matplotlib.pyplot as plt
import torchaudio

from models.multi_level_vqvae import MultiLvlVQVariationalAutoEncoder
from loaders.music_loader import MP3SliceDataset
from loaders.lvl2_loader import Lvl2InputDataset
from utils.other import load_cfg_dict

## Load the configuration files and set the checkpoint paths

In [None]:
# Level-1 model: config file and trained checkpoint.
weights_path_lvl1 = "model_weights/lvl1_vqvae.ckpt"
config_path_lvl1 = "config/lvl1_config.yaml"
cfg_1 = load_cfg_dict(config_path_lvl1)

# Level-2 model: config file and trained checkpoint.
weights_path_lvl2 = "model_weights/lvl2_vqvae.ckpt"
config_path_lvl2 = "config/lvl2_config.yaml"
cfg_2 = load_cfg_dict(config_path_lvl2)

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Define Mel Spectrogram

In [None]:
# Mel spectrogram transform: 512-point FFT, hop of 128 samples, 128 mel bands,
# magnitude output (power=1.0), Slaney-normalised filters on the HTK mel scale.
# NOTE(review): sample_rate is left at torchaudio's default, while audio in this
# notebook is saved at 44100 Hz. Confirm this matches how the models were trained.
mel_spec_new = torchaudio.transforms.MelSpectrogram(
    n_fft=512,
    hop_length=128,
    n_mels=128,
    pad_mode='reflect',
    power=1.0,
    norm="slaney",
    mel_scale="htk",
)
mel_spec_new = mel_spec_new.to(device)

## Load the models

In [None]:
# Restore both VQ-VAE levels from their checkpoints and put them in eval mode.
# `load_from_checkpoint(path, **hparams, strict=...)` matches the PyTorch
# Lightning classmethod. Calling it on the class avoids first building a
# throw-away randomly initialised instance; newer Lightning versions refuse
# instance calls entirely.
# NOTE(review): assumes MultiLvlVQVariationalAutoEncoder is a LightningModule; confirm.
# NOTE(review): strict=False silently ignores missing/unexpected state-dict keys.
model_lvl1 = MultiLvlVQVariationalAutoEncoder.load_from_checkpoint(
    weights_path_lvl1, **cfg_1, strict=False
).to(device)
model_lvl1.eval()

# Load the trained level-2 weights as well. Without this, every level-2 result
# below comes from a randomly initialised network.
model_lvl2 = MultiLvlVQVariationalAutoEncoder.load_from_checkpoint(
    weights_path_lvl2, **cfg_2, strict=False
).to(device)
model_lvl2.eval()

### HTML wrapper to play sound clips

In [None]:
# This wrapper takes a filename and publishes an HTML <audio> tag to listen to it.

def wavPlayer(filepath):
    """Display an HTML5 audio player for ``filepath`` in the notebook output.

    Parameters
    ----------
    filepath : str
        Path of the audio file relative to the notebook directory (where the
        .ipynb lives, which is not necessarily the kernel's cwd). It is
        requested through the notebook server's ``files/`` route.

    The ``<source>`` element declares ``type="audio/mp3"``, so the browser
    must be able to play MP3 through HTML5 audio. No ``autoplay`` attribute
    is set, so nothing starts playing when the page opens.
    """
    # `HTML` (and, outside an IPython kernel, `display`) is not imported by this
    # notebook's import cell, so the helper previously raised NameError.
    import html
    from IPython.display import HTML, display

    # Escape the path so quotes or '&' in a filename cannot break the markup.
    safe_path = html.escape(str(filepath), quote=True)
    src = """
    <head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <title>Simple Test</title>
    </head>
    
    <body>
    <audio controls="controls" style="width:600px" >
      <source src="files/%s" type="audio/mp3" />
      Your browser does not support the audio element.
    </audio>
    </body>
    """ % (safe_path)
    display(HTML(src))

## Load a dataset sample and display it.

### Create a diagonal weight matrix over the 128 mel bins

In [None]:
# 128 weights increasing linearly from 0.1 to 5, placed on the diagonal of a
# 128x128 matrix (presumably one weight per mel bin, since n_mels=128).
lin_vector = torch.linspace(start=0.1, end=5, steps=128)
eye_mat = lin_vector.diag().to(device)

### Load the slice

In [None]:
# Draw one random level-2 input slice (shuffle=True) and show it as an image.
dataset = Lvl2InputDataset(preload=True)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
# Take the first batch directly. A `for ... break` loop would leave
# `lvl2_sample` undefined, or stale from an earlier run, on an empty dataset.
try:
    sample = next(iter(dataloader))
except StopIteration:
    raise RuntimeError("Lvl2InputDataset is empty; nothing to display") from None
# batch_size=1, so squeeze away the leading batch dimension.
lvl2_sample = sample['music slice'].squeeze(0)
print(f"Current track is {sample['track name'][0]}")

plt.figure(figsize=(25, 5))
plt.matshow(lvl2_sample.cpu().detach().numpy(), fignum=1, aspect='auto', vmin=-0.5, vmax=2.0)
plt.colorbar()
plt.axis('off')
plt.show()

## Run the lvl 2 network and show results

In [None]:
# Run the level-2 VQ-VAE on the slice and plot its reconstruction.
# The forward pass itself must be inside no_grad, not just the indexing.
# Otherwise autograd records the whole level-2 graph for an inference-only call.
with torch.no_grad():
    lvl2_output = model_lvl2(lvl2_sample.clone().unsqueeze(0).to(device), extract_losses=False)
    # Index 0 drops the batch dimension added by unsqueeze(0) above.
    lvl1_pred = lvl2_output['output'][0]

plt.figure(figsize=(25, 5))
plt.matshow(lvl1_pred.cpu().detach().numpy(), fignum=1, aspect='auto', vmin=-0.5, vmax=2.0)
plt.colorbar()
plt.axis('off')
plt.show()
print(lvl1_pred.size(), lvl2_sample.size())

## Show the absolute difference between the ground truth and the reconstruction

In [None]:
# Plot |ground truth - reconstruction| element-wise on a fixed 0..2 colour range.
plt.figure(figsize=(25, 5))
plt.matshow(
    (lvl2_sample.cpu() - lvl1_pred.cpu()).abs().detach().numpy(),
    fignum=1,
    aspect='auto',
    vmin=0,
    vmax=2.0,
)
plt.colorbar()
plt.axis('off')
plt.show()

## Pass the level 2 outputs through the level 1 decoder and make MUSIC!

### Sanity check: reconstruct the level 1 data

In [None]:
# Sanity check: quantise the ground-truth level-2 input with the level-1 VQ
# module, decode it with the level-1 decoder, then plot, save and play it.
with torch.no_grad():
    z_q_out = model_lvl1.vq_module(lvl2_sample.unsqueeze(0).to(device))['v_q']
    output = model_lvl1.decoder(z_q_out)

# Flatten to shape (1, num_samples). reshape (unlike view) also works if the
# decoder output happens to be non-contiguous.
music_sample_rec = output.reshape((1, -1))
plt.figure(figsize=(25, 5))
plt.plot(music_sample_rec[0, ...].cpu().detach().numpy())
plt.ylim((-1.1, 1.1))
plt.show()
# Saved at a hard-coded 44100 Hz sample rate.
torchaudio.save('sample_out.mp3', music_sample_rec.cpu().detach(), 44100, format='mp3')
IPython.display.Audio(filename="sample_out.mp3")

### CREATE *MUSIC*!!!

In [None]:
# Decode the level-2 network's prediction with the level-1 VQ module and
# decoder to produce audio, then plot, save and play it.
with torch.no_grad():
    z_q_out = model_lvl1.vq_module(lvl1_pred.unsqueeze(0).to(device))['v_q']
    output = model_lvl1.decoder(z_q_out)

# Flatten to shape (1, num_samples). reshape (unlike view) also works if the
# decoder output happens to be non-contiguous.
music_sample_rec = output.reshape((1, -1))
plt.figure(figsize=(25, 5))
plt.plot(music_sample_rec[0, ...].cpu().detach().numpy())
plt.ylim((-1.1, 1.1))
plt.show()
# Saved at a hard-coded 44100 Hz sample rate; overwrites the sanity-check file.
torchaudio.save('sample_out.mp3', music_sample_rec.cpu().detach(), 44100, format='mp3')
IPython.display.Audio(filename="sample_out.mp3")