# Level 4 sequence generation

In [None]:
import IPython
import copy

import torch
from torch.utils.data import DataLoader, TensorDataset
import yaml
import matplotlib.pyplot as plt
import torchaudio

from models.multi_level_vqvae import MultiLvlVQVariationalAutoEncoder
from models.infected_lm import TransformerAutoregressor
from loaders.latent_loaders import Lvl5InputDataset
from utils.other import load_cfg_dict

### Load configuration files and weights

In [None]:
# --- VQ-VAE level configs and checkpoints -----------------------------------
config_path_lvl1, weights_path_lvl1 = "config/lvl1_config.yaml", "model_weights/lvl1_vqvae.ckpt"
config_path_lvl2, weights_path_lvl2 = "config/lvl2_config.yaml", "model_weights/lvl2_vqvae.ckpt"
config_path_lvl3, weights_path_lvl3 = "config/lvl3_config.yaml", "model_weights/lvl3_vqvae.ckpt"
config_path_lvl4, weights_path_lvl4 = "config/lvl4_config.yaml", "model_weights/lvl4_vqvae.ckpt"

# --- Dataset config (no checkpoint is loaded from it) -----------------------
# NOTE(review): the variable is named "lvl5" but points at diff_lvl4_config.yaml;
# confirm which level is intended.
config_path_diff_lvl5 = "config/diff_lvl4_config.yaml"

# --- Level-4 sequence (transformer) model -----------------------------------
config_path_seq_lvl4, weights_path_lvl4_seq = "config/seq_lvl4_config.yaml", "model_weights/lvl4_seq.ckpt"

# Parse each config file with the project helper, in the same order as before.
cfg_1 = load_cfg_dict(config_path_lvl1)
cfg_2 = load_cfg_dict(config_path_lvl2)
cfg_3 = load_cfg_dict(config_path_lvl3)
cfg_4 = load_cfg_dict(config_path_lvl4)
cfg_diff_5 = load_cfg_dict(config_path_diff_lvl5)
cfg_seq_4 = load_cfg_dict(config_path_seq_lvl4)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Load the models

In [None]:
def _restore_vqvae(checkpoint, cfg):
    """Load one VQ-VAE level from `checkpoint`, move it to `device`, set eval mode.

    `cfg` is splatted into `load_from_checkpoint` as keyword arguments;
    `strict=False` lets the checkpoint's state-dict keys differ from the module's.
    """
    vqvae = MultiLvlVQVariationalAutoEncoder.load_from_checkpoint(checkpoint, **cfg, strict=False).to(device)
    vqvae.eval()
    return vqvae


model_lvl1 = _restore_vqvae(weights_path_lvl1, cfg_1)
model_lvl2 = _restore_vqvae(weights_path_lvl2, cfg_2)
model_lvl3 = _restore_vqvae(weights_path_lvl3, cfg_3)
model_lvl4 = _restore_vqvae(weights_path_lvl4, cfg_4)

# The sequence model shares the level-4 VQ codebook, so it must be built after model_lvl4.
model_seq_lvl4 = TransformerAutoregressor.load_from_checkpoint(
    weights_path_lvl4_seq,
    **cfg_seq_4,
    codebook=model_lvl4.vq_module.vq_codebook,
    strict=False,
).to(device)
# Last expression of the cell, so the notebook still echoes the model here.
model_seq_lvl4.eval()

### HTML wrapper to display sound clips

In [None]:
# Wrapper that takes a filename and displays an HTML5 <audio> player for it.


def _wav_player_html(filepath, mime_type=None):
    """Build the HTML ``<audio>`` snippet shown by ``wavPlayer``.

    Parameters :
    ------------
    filepath  : path of the audio file, relative to the notebook directory
                (it is served through Jupyter's ``files/`` route)
    mime_type : MIME type advertised in the ``<source>`` tag; when None it is
                inferred from the file extension, defaulting to "audio/wav"

    Returns the HTML snippet as a string.
    """
    import html
    import os.path
    from urllib.parse import quote

    path = str(filepath)  # same coercion the old "%s" formatting performed
    if mime_type is None:
        extension = os.path.splitext(path)[1].lower()
        mime_type = {
            ".wav": "audio/wav",
            ".mp3": "audio/mpeg",  # "audio/mp3" is not a registered MIME type
            ".ogg": "audio/ogg",
            ".flac": "audio/flac",
        }.get(extension, "audio/wav")

    # Percent-encode the path so spaces, '#', '?' and quotes survive as a URL,
    # then HTML-escape both attribute values before interpolating them.
    src = html.escape("files/" + quote(path), quote=True)
    source_type = html.escape(mime_type, quote=True)
    return (
        '<audio controls="controls" style="width:600px">\n'
        f'  <source src="{src}" type="{source_type}" />\n'
        "  Your browser does not support the audio element.\n"
        "</audio>\n"
    )


def wavPlayer(filepath, mime_type=None):
    """Display an HTML5 audio player for `filepath` in the notebook output.

    Parameters :
    ------------
    filepath  : path of the file to play, relative to the notebook directory
                (not necessarily the process cwd)
    mime_type : optional MIME type override; inferred from the extension if None

    The browser must be able to play the format through HTML5 <audio>.
    There is no autoplay, so nothing starts playing when the notebook opens.
    """
    # Imported here: HTML was never imported at module level, so the original
    # call raised NameError; importing display too avoids relying on the
    # IPython-injected builtin.
    from IPython.display import HTML, display

    display(HTML(_wav_player_html(filepath, mime_type)))

# Load a dataset and display it

### Load the slice

In [None]:
dataset = Lvl5InputDataset(preload=True, **cfg_diff_5)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Draw a single random sample (shuffle=True). next(iter(...)) replaces the old
# "for ...: break" loop, which left lvl4_latent undefined on an empty dataset.
try:
    sample = next(iter(dataloader))
except StopIteration:
    raise RuntimeError("Lvl5InputDataset yielded no samples") from None

# squeeze(0) drops the batch dimension added by the DataLoader (batch_size=1).
lvl4_latent = sample['music slice'].squeeze(0)
lvl4_latent_prev = sample["back conditional slice"].squeeze(0)
print(f"Current track is {sample['track name'][0]}")

# Show the current slice, then the conditioning (previous) slice, on the same color scale.
for latent_to_show in (lvl4_latent, lvl4_latent_prev):
    plt.figure(figsize=(25, 5))
    plt.matshow(latent_to_show.detach().cpu().numpy(), fignum=1, aspect='auto', vmin=-0.5, vmax=2.0)
    plt.colorbar()
    plt.axis('off')
    plt.show()

## Generate new music: sample a level-4 latent sequence and decode it to audio

In [None]:
# Every decoded latent is cut into `batch_divide` consecutive time segments
# before it goes to the next (lower) level's decoder; each level is decoded in
# mini-batches of `batch_size`.
batch_divide = 8
batch_size = 4
# NOTE(review): hard-coded output sample rate; confirm it matches the rate the
# level-1 model was trained on (it is not read from cfg_1 here).
sample_rate = 44100


def _split_along_time(latent, n_segments, depth):
    """Cut every item of a channel-first batch into `n_segments` time segments.

    `latent` is treated as (batch, depth, time). The result has shape
    (batch * n_segments, depth, time // n_segments). Segments of the same item
    stay consecutive and in time order. The batch size is read from `latent`
    itself, so a short final DataLoader mini-batch is split correctly.
    """
    return (
        latent.permute((0, 2, 1))
        .reshape((latent.shape[0] * n_segments, -1, depth))
        .permute((0, 2, 1))
    )


def _merge_along_time(chunks):
    """Join the items of a (batch, channels, time) tensor end-to-end in time.

    Returns a (1, channels, batch * time) tensor. The channel count is taken
    from the tensor instead of being hard-coded to 1.
    """
    return (
        chunks.permute((0, 2, 1))
        .reshape((1, -1, chunks.shape[1]))
        .permute((0, 2, 1))
    )


print(lvl4_latent.size())

with torch.no_grad():
    # Generate a new level-4 latent sequence conditioned on the previous slice.
    lvl4_latent_predicted = model_seq_lvl4.generate_sequence(lvl4_latent, prev_slice=lvl4_latent_prev, temperature=1)

    # Level-4 decoder: one generated sequence -> one level-3 latent (batch of 1).
    lvl3_pred, _ = model_lvl4.decode(lvl4_latent_predicted.unsqueeze(0).to(device))

    # Cut the level-3 latent into `batch_divide` segments for the level-3 decoder.
    lvl4_sample_divided = _split_along_time(lvl3_pred, batch_divide, cfg_3['latent_depth'])
    print(lvl4_sample_divided.size())

    plt.figure(figsize=(25, 5))
    plt.matshow(lvl3_pred.squeeze(0).cpu().numpy(), fignum=1, aspect='auto', vmin=-2, vmax=2.0)
    plt.colorbar()
    plt.axis('off')
    plt.show()

    # Decode depth-first (level 3 -> 2 -> 1), so only the intermediate latents
    # of the current branch are alive at any time. Waveform chunks are
    # collected in order and concatenated once at the end.
    waveform_chunks = []
    lvl4_dataloader = DataLoader(TensorDataset(lvl4_sample_divided), batch_size=batch_size)
    for (lvl4_ind_sample,) in lvl4_dataloader:
        output_lvl3, _ = model_lvl3.decode(lvl4_ind_sample.to(device))
        output_lvl3 = _split_along_time(output_lvl3, batch_divide, cfg_2['latent_depth'])

        lvl3_dataloader = DataLoader(TensorDataset(output_lvl3), batch_size=batch_size)
        for (lvl3_ind_sample,) in lvl3_dataloader:
            output_lvl2, _ = model_lvl2.decode(lvl3_ind_sample.to(device))
            output_lvl2 = _split_along_time(output_lvl2, batch_divide, cfg_1['latent_depth'])

            lvl2_dataloader = DataLoader(TensorDataset(output_lvl2), batch_size=batch_size)
            for (lvl2_ind_sample,) in lvl2_dataloader:
                output_lvl1_ind, _ = model_lvl1.decode(lvl2_ind_sample.to(device))
                waveform_chunks.append(_merge_along_time(output_lvl1_ind))

    # A single concatenation: torch.cat inside the loop re-copied the whole
    # growing waveform on every iteration.
    output_lvl1 = torch.cat(waveform_chunks, dim=2)

# (channels, time) layout for torchaudio.save; for mono output this equals the
# previous view((1, -1)), and it no longer flattens multi-channel audio.
music_sample_rec = output_lvl1[0]
plt.figure(figsize=(25, 5))
plt.plot(music_sample_rec[0].cpu().numpy())
plt.ylim((-1.1, 1.1))
plt.show()
torchaudio.save('sample_out.mp3', music_sample_rec.cpu(), sample_rate, format='mp3')
IPython.display.Audio(filename="sample_out.mp3")