# Auto-Regressive Chaos Hydra Test Notebook

### Change the working directory to the root (the notebook directory's parent)

In [None]:
import os

# Move from the notebook's directory up to its parent; the relative paths used
# later in this notebook (e.g. "weights/...") are resolved from there.
# The target is computed only once per kernel session, so re-running this cell
# does not keep climbing one more directory each time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)

### Imports

In [None]:
import IPython

import torch
import matplotlib.pyplot as plt
import torchaudio
import numpy as np

from common import registry
from utils.config import load_cfg_from_hydra
from models.pipeline import AutoRegressivePipeline

### Load Model

In [None]:
# Checkpoint locations, relative to the current working directory.
tokenizer_path = "weights/tokenizer_best.ckpt"
mamba_path = "weights/mamba_best.ckpt"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# (To force CPU execution, assign device = "cpu" after this.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Current device is {device}")


def _load_lightning_module(config_name, checkpoint_path):
    """Load Hydra config ``config_name`` and build its Lightning module from a checkpoint.

    The config's ``learning.batch_size`` and ``learning.val_split`` are overridden
    before the module is built. The module class is looked up in the project
    registry by ``cfg.model.module_type``, moved to the global ``device`` and put
    in eval mode.

    Returns:
        ``(cfg, module)`` — the loaded config and the ready-to-use module.
    """
    config = load_cfg_from_hydra(config_path="../config", config_name=config_name)
    config.learning.batch_size = 512
    # NOTE(review): val_split = 1.0 is an unusual setting; whether the config /
    # data loading accepts it is unverified — confirm.
    config.learning.val_split = 1.0
    module_cls = registry.get_lightning_module(config.model.module_type)
    module = module_cls.from_cfg(config, checkpoint_path).to(device).eval()
    return config, module


# Tokenizer, built from the "config" Hydra config.
cfg, tokenizer = _load_lightning_module("config", tokenizer_path)

# Chaos Hydra, built from the "mamba" Hydra config.
cfg_hydra, chaos_hydra = _load_lightning_module("mamba", mamba_path)


### Wav Player Element

In [None]:
# Helper that takes a file path and displays an HTML5 <audio> element for listening to it.

def wavPlayer(filepath):
    """Display an HTML5 audio player for ``filepath`` in the notebook output.

    Parameters
    ----------
    filepath : str
        Path of the audio file, relative to the notebook's directory (the
        directory containing the .ipynb file, which is not necessarily the
        current working directory). The player's URL is ``files/<filepath>``.

    Notes
    -----
    The browser must be able to play the file's format through HTML5 audio.
    There is no autoplay, so nothing starts playing when the notebook opens.
    """
    # Local imports: the module only imports ``IPython`` itself, so ``HTML``
    # (and ``display`` outside an IPython kernel) would otherwise be undefined.
    from html import escape
    from urllib.parse import quote

    from IPython.display import HTML, display

    # Advertise the real format instead of always claiming MP3; "audio/mpeg" is
    # the standard MIME type for MP3 and is used for unknown extensions.
    mime_types = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".ogg": "audio/ogg",
        ".flac": "audio/flac",
    }
    extension = os.path.splitext(filepath)[1].lower()
    mime_type = mime_types.get(extension, "audio/mpeg")

    # URL-quote the path (spaces, '#', '?', ...) and HTML-escape the attribute
    # value so unusual file names cannot break the generated markup.
    src_url = escape("files/" + quote(filepath), quote=True)

    markup = (
        '<audio controls="controls" style="width:600px">\n'
        f'    <source src="{src_url}" type="{mime_type}" />\n'
        "    Your browser does not support the audio element.\n"
        "</audio>"
    )
    display(HTML(markup))

## Generate music from a fixed initial point

In [None]:
# Length argument for create_fixed_music_slice.
# NOTE(review): units (audio samples vs. tokens) are defined by the pipeline — confirm.
slice_length = 8192

# Seed generation with uniformly random tokens instead of a constant prompt.
use_random_start = False
# Token id used to fill the constant prompt.
# NOTE(review): the meaning of 1024 (e.g. a special/start token) is not visible here — confirm.
start_token = 1024
# Prompt tensor shape; presumably (batch, ?, prompt_length) — TODO confirm the axes.
prompt_shape = (1, 1, 16)

pipeline = AutoRegressivePipeline(tokenizer.model, chaos_hydra)  # type: ignore

# Build the prompt directly as int32 on the target device.
# The name is kept for compatibility; the prompt is only random when use_random_start is True.
if use_random_start:
    random_initial_point = torch.randint(
        0, cfg_hydra.model.vocabulary_size, prompt_shape, dtype=torch.int32, device=device
    )
else:
    random_initial_point = torch.full(prompt_shape, start_token, dtype=torch.int32, device=device)

with torch.no_grad():
    series = pipeline.create_fixed_music_slice(
        random_initial_point, slice_length, top_k=5, temperature=0.6
    )
print(series.shape)

### Display and play

In [None]:
# Sample rate (Hz) written into the audio file.
sample_rate = 44100
# Output file, shared by the save and playback steps so they cannot drift apart.
output_path = "sample.mp3"

# Detach from any autograd graph and move to host memory once; reused below.
audio = series.detach().cpu()  # type: ignore

# flatten() joins all dimensions into one trace, so a multi-channel tensor
# would be plotted channel after channel.
# NOTE(review): assumes `series` is mono — confirm its shape (printed above).
waveform = audio.flatten().numpy()

plt.figure(figsize=(30, 5))
plt.plot(waveform)
plt.tight_layout()
plt.ylim(-1.2, 1.2)
plt.show()

# torchaudio.save expects a 2-D (channels, frames) tensor by default.
# TODO confirm `series` already has that layout.
torchaudio.save(output_path, audio, sample_rate, format="mp3")  # type: ignore
IPython.display.Audio(filename=output_path)  # type: ignore