# Diffusion LLM Music Testing Notebook

### Change the working directory to the project root (the notebook directory's parent)

In [None]:
import os

# Remember the directory the kernel started in, then always resolve the root
# from it. Without this, every re-run of the cell would chdir("..") again and
# climb one more level up the directory tree.
if "_NOTEBOOK_DIR" not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, os.pardir))
print(os.getcwd())

### Imports

In [None]:
import random

import torch
import hydra
import matplotlib.pyplot as plt
import torchaudio
import IPython
import hydra

from models.modules.music import MusicLightningModule
from models.modules.diffusion_llm import DiffusionLLMLightningModule
from utils.config import load_cfg_from_hydra
from utils.transform_func import log_normal
from utils.containers import MelSpecParameters
from models.mel_spec_converters import SimpleMelSpecConverter
from models.modules.base import load_inner_model_state_dict

### Load DLLM

In [None]:
weights_path = "weights/lvl1_dllm.ckpt"

device = "cpu"
print(f"Current device: {device}")

# Level-1 diffusion-LLM config, with its `data` group overridden to lvl1_vqvae.
cfg_dllm = load_cfg_from_hydra(
    config_path="../config",
    config_name="lvl1_dllm",
    overrides=["data=lvl1_vqvae"],
)
# Notebook settings: one sample per batch, val_split forced to 1.0.
cfg_dllm.learning.batch_size = 1
cfg_dllm.learning.val_split = 1.0

# Instantiate the Lightning module from config, then apply the checkpoint at
# `weights_path` via the project helper (presumably loads the inner model's
# state dict — see models.modules.base).
model_dllm = hydra.utils.instantiate(cfg_dllm.module, _convert_="partial").to(device)
model_dllm: DiffusionLLMLightningModule = load_inner_model_state_dict(
    model_dllm,
    weights_path,
).to(device)

### Load Tokenizer

In [None]:


# NOTE: reuses (overwrites) `weights_path` from the DLLM cell.
weights_path = "trained/lvl1_vqvae/model.ckpt"

# The tokenizer's config lives next to its trained checkpoint in trained/lvl1_vqvae.
cfg_tokenizer = load_cfg_from_hydra(
    config_path="../trained/lvl1_vqvae",
    config_name="config",
)
cfg_tokenizer.learning.batch_size = 128
cfg_tokenizer.learning.val_split = 1.0

# Build the tokenizer module on `device` and load its checkpoint weights.
model_tokenizer = hydra.utils.instantiate(cfg_tokenizer.module, _convert_="partial").to(device)
model_tokenizer: MusicLightningModule = load_inner_model_state_dict(
    model_tokenizer,
    weights_path,
).to(device)

### Initialize Mel Spec Parameters

In [None]:
# Three mel-spectrogram resolutions; each step doubles n_fft, hop_length and n_mels.
n_mels = 64
mel_spec_params = MelSpecParameters(
    n_fft=1024, f_min=0, hop_length=256, n_mels=n_mels, power=1.0, pad=0
)
mel_spec_converter = SimpleMelSpecConverter(mel_spec_params)

mel_spec_params_2 = MelSpecParameters(
    n_fft=2048, f_min=0, hop_length=512, n_mels=128, power=1.0, pad=0
)
mel_spec_converter_2 = SimpleMelSpecConverter(mel_spec_params_2)

mel_spec_params_3 = MelSpecParameters(
    n_fft=4096, f_min=0, hop_length=1024, n_mels=256, power=1.0, pad=0
)
mel_spec_converter_3 = SimpleMelSpecConverter(mel_spec_params_3)

# Diagonal per-mel-bin weighting matrix. Start and end of the ramp are both 1.0,
# so it is currently the identity; change the end value to tilt the weighting.
lin_vector = torch.linspace(1.0, 1.0, n_mels)
eye_mat = torch.diag(lin_vector).to(device)

### WAV Player element

In [None]:
# Helper that takes a file path and renders an HTML5 <audio> element to play it.

def wavPlayer(filepath, mime_type=None):
    """Display an HTML5 audio player for ``filepath`` in the notebook output.

    Parameters
    ----------
    filepath : str
        Path of the audio file relative to the notebook directory (the
        directory holding the .ipynb, which is not necessarily the current
        working directory). It is served through Jupyter's ``files/`` route.
    mime_type : str, optional
        MIME type advertised in the ``<source>`` tag. When omitted it is
        guessed from the file extension, falling back to ``"audio/mp3"``.

    The browser must be able to play the file's format through HTML5. There
    is no autoplay, so nothing starts playing when the notebook is opened.
    """
    # Imported here so the helper works even though the notebook's import cell
    # does not bring these names into scope.
    import html
    import mimetypes
    from urllib.parse import quote

    from IPython.display import HTML, display

    if mime_type is None:
        guessed, _ = mimetypes.guess_type(str(filepath))
        mime_type = guessed if guessed and guessed.startswith("audio/") else "audio/mp3"

    # Percent-encode the path so spaces, quotes or '<' in a filename cannot
    # break out of the src attribute.
    src = (
        '<audio controls="controls" style="width:600px">\n'
        f'    <source src="files/{quote(str(filepath))}" type="{html.escape(mime_type)}" />\n'
        "    Your browser does not support the audio element.\n"
        "</audio>\n"
    )
    display(HTML(src))

# Data Sample Processing

### Create a Data Sample

In [None]:
# Pull the token-space sizes from the tokenizer's VQ module config.
vq_cfg = cfg_tokenizer.model.vq_module
vocab_size = vq_cfg.token_dim
num_rq_steps = vq_cfg.num_rq_steps

seq_length = 512
num_seq = 4

# Uniformly random token ids in [0, vocab_size), shaped (num_seq, seq_length, num_rq_steps).
sample_shape = (num_seq, seq_length, num_rq_steps)
sample = torch.randint(0, vocab_size, sample_shape).to(device)

### Decode a Waveform from the Random Tokens

In [None]:
# Decode the random token grid into audio. This is inference only, so skip
# building an autograd graph for the decoder forward pass.
with torch.no_grad():
    generated_waveform: torch.Tensor = model_tokenizer.model.from_tokens(sample)

# Flattened, detached CPU copy of the decoded audio, shared by the plots and the export.
waveform_cpu = generated_waveform.detach().flatten().cpu()
image_reshaped = waveform_cpu.numpy()

plt.figure(figsize=(30, 5))
plt.plot(image_reshaped)
plt.tight_layout()
plt.ylim(-1.2, 1.2)
plt.show()

# Log-normalised mel spectrograms at the three resolutions set up earlier.
# plt.matshow opens its own figure, so no plt.figure() call precedes it
# (one used to, which only left an empty figure in the output).
for converter in (mel_spec_converter, mel_spec_converter_2, mel_spec_converter_3):
    mel = converter.convert(torch.tensor(image_reshaped))
    plt.matshow(
        log_normal(mel).cpu().numpy(),
        origin="lower",
        aspect="auto",
        vmin=-2,
        vmax=2,
    )
    plt.show()

# NOTE(review): assumed output sample rate of the decoder — confirm against the tokenizer config.
sample_rate = 44100
torchaudio.save("sample.mp3", waveform_cpu.unsqueeze(0), sample_rate, format="mp3")  # type: ignore
IPython.display.Audio(filename="sample.mp3")  # type: ignore

# Generated Data Processing

In [None]:
# Sample from the diffusion LLM. As the cell's last expression, its return value
# is echoed in the notebook output.
# NOTE(review): called with no arguments — confirm generate()'s defaults (length,
# number of steps, device) are what this experiment intends.
model_dllm.generate()