# Infected HDiff Res1D Inference Notebook

## Setup and Imports

### Change working directory

In [None]:
import os
# Step up one directory level so that the relative paths used in the cells
# below (config/, weights/, data/) resolve from the parent directory.
os.chdir(os.pardir)

### Imports

In [None]:
import torch
import random
import seaborn as sns
import numpy as np
import pandas as pd
import torchaudio
import IPython

from loaders.build import build_mel_spec_module
from models.build import build_res1d_vocoder, build_mel_spec_converter
from utils.containers import MelSpecParameters, LearningParameters, DiffusionParameters, MusicDatasetParameters, parse_cfg_for_vocoder
from utils.others import load_config

### Load configuration

In [None]:
# YAML configuration consumed by the builder and parameter containers below.
cfg_path = "config/vocoders/config_res1d.yaml"
cfg = load_config(cfg_path)

# Use the GPU when torch reports one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"The current device is {device.upper()}")

# Checkpoint path handed to the vocoder builder in the model-loading cell.
weights_path = "weights/vocoder/res1d_best.ckpt"

### Load model

In [None]:
# Build the vocoder from the config and checkpoint (no loss aggregator is
# passed), then switch it to eval mode and move it to the selected device.
vocoder = build_res1d_vocoder(cfg, loss_aggregator=None, weights_path=weights_path)
vocoder = vocoder.eval().to(device)
print(f"Loaded vocoder from path {weights_path}")

### Load dataset

In [None]:
# Each parameter container is filled from its matching section of the config.
dataset_params = MusicDatasetParameters(**cfg["dataset"])
learning_params = LearningParameters(**cfg["learning"])
mel_spec_params = MelSpecParameters(**cfg["image_mel_spec_params"])

# Override the configured batch size with 1; the data-point cell calls
# .item() on a per-batch field, which needs exactly one element.
learning_params.batch_size = 1

data_module = build_mel_spec_module(
    dataset_params,
    learning_params,
    mel_spec_params,
)

### Load data-point

In [None]:
data_module.setup("fit")
loader = data_module.train_dataloader()

# Only the first batch is needed, so the loop exits after a single iteration.
for batch in loader:
    print("File name: {}".format(batch["slice_file_name"][0]))
    print("File index: {}".format(batch["slice_idx"].item()))
    # NOTE(review): the name `slice` shadows the Python builtin; it is kept
    # because the cells below read this variable.
    slice = batch["slice"]
    print("The slice dimensions are {}".format(tuple(slice.shape)))
    break

### Create player object

In [None]:
# This is a wrapper that takes a filename and publishes an HTML <audio> tag to listen to it.

def wavPlayer(filepath):
    """Display an HTML5 ``<audio>`` player for an audio file.

    Parameters
    ----------
    filepath : str or path-like
        Path of the file to play, relative to the notebook directory (the
        directory that holds the ``.ipynb`` file, not the current working
        directory). It is appended to ``files/`` to form the ``src`` URL.

    Notes
    -----
    The browser needs to know how to play the file format through HTML5.
    There is no autoplay, to prevent the file from playing when the browser
    opens.
    """
    # Imported locally because the notebook's import cell does not bring
    # ``HTML`` (or ``display``) into scope, so the call used to raise NameError.
    import html
    import mimetypes

    from IPython.display import HTML, display

    path_text = str(filepath)
    # Derive the MIME type from the file extension so wav files are not
    # labelled as mp3; fall back to the previous fixed value when unknown.
    mime_type = mimetypes.guess_type(path_text)[0] or "audio/mp3"

    src = """
    <head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <title>Simple Test</title>
    </head>
    
    <body>
    <audio controls="controls" style="width:600px" >
      <source src="files/{path}" type="{mime}" />
      Your browser does not support the audio element.
    </audio>
    </body>
    """.format(
        # Escape quotes and markup characters so the path cannot break out of
        # the ``src`` attribute.
        path=html.escape(path_text, quote=True),
        mime=html.escape(mime_type, quote=True),
    )
    display(HTML(src))

## Present and Process the Slice

### Show Mel-Spectrogram

In [None]:
# Build the "simple" mel-spectrogram converter and apply it to the waveform
# slice taken from the data loader.
mel_spec_converter = build_mel_spec_converter("simple", mel_spec_params)
slice_mel_spec = mel_spec_converter.convert(slice)

# Wide figure for the heatmap.
sns.set(rc={"figure.figsize": (20, 5)})

# The two squeeze(0) calls drop leading singleton axes before plotting
# (presumably batch and channel - TODO confirm against the converter output).
ax = sns.heatmap(
    data=slice_mel_spec.squeeze(0).squeeze(0).cpu().numpy(),
    cmap="mako",
    robust=True,
)
# Flip the vertical axis so row 0 is drawn at the bottom of the plot.
ax.invert_yaxis()

### Show waveform and play it

In [None]:
# Time axis in seconds: samples are numbered from 1, so the first point sits
# at 1 / sample_rate.
x = np.arange(1, slice.shape[-1] + 1) / dataset_params.sample_rate
y = slice.squeeze(0).squeeze(0).cpu().numpy()
data = pd.DataFrame({"Time [s]": x, "Signal": y})
sns.set_style("darkgrid")
sns.set(rc={'figure.figsize': (20, 5)})
ax = sns.lineplot(data=data, x="Time [s]", y="Signal")
ax.set_ylim((-1, 1))
# Save with the same sample rate that defines the time axis above (it was
# hard-coded to 44100 before), so playback speed matches the plot.
torchaudio.save(
    'sample_out.mp3',
    slice.squeeze(0).cpu().detach(),
    int(dataset_params.sample_rate),
    format='mp3',
)
IPython.display.Audio(filename="sample_out.mp3")

### Reconstruct the sound-wave

In [None]:
# The vocoder takes a dict with the mel-spectrogram under "mel_spec";
# squeeze(1) drops axis 1 when it is a singleton.
inputs = {"mel_spec": slice_mel_spec.squeeze(1).to(device)}
with torch.no_grad():
    # Call the module instance rather than .forward() so that any registered
    # forward hooks run, as recommended by the torch.nn.Module documentation.
    outputs = vocoder(inputs)
# The reconstructed waveform is returned under the "slice" key.
output_slice = outputs["slice"].squeeze(1)

### Play the reconstructed slice

In [None]:
# Size the time axis from the reconstructed signal itself (it was sized from
# the input slice before), so x and y always have the same length even if the
# vocoder output length differs from the input length.
x = np.arange(1, output_slice.shape[-1] + 1) / dataset_params.sample_rate
y = output_slice.squeeze(0).cpu().numpy()
data = pd.DataFrame({"Time [s]": x, "Signal": y})
sns.set(rc={'figure.figsize': (20, 5)})
ax = sns.lineplot(data=data, x="Time [s]", y="Signal")
ax.set_ylim((-1, 1))
# Save with the same sample rate that defines the time axis above (it was
# hard-coded to 44100 before), so playback speed matches the plot.
torchaudio.save(
    'sample_out.mp3',
    output_slice.cpu().detach(),
    int(dataset_params.sample_rate),
    format='mp3',
)
IPython.display.Audio(filename="sample_out.mp3")

## Reconstruct Unknown Slice

### Extract novel slice

In [None]:
test_track_path = "data/test_tracks/bhaalspawn.mp3"
# torchaudio.load returns (waveform, sample_rate); only the waveform is kept.
slice_new_big: torch.Tensor = torchaudio.load(test_track_path, format="mp3")[0]
# Average over axis 0 (the channel axis in torchaudio's default layout) to
# down-mix to a single channel while keeping that axis with size 1.
slice_new_big = slice_new_big.mean(dim=0, keepdim=True)

# Number of samples in the extracted window.
slice_len = 32768 * 8
# Draw the start index from the range in which the whole window fits inside
# the track, instead of a fixed 8e6-sample range that yields a truncated or
# empty slice on shorter tracks. A track shorter than the window starts at 0.
max_start_idx = max(slice_new_big.shape[-1] - slice_len, 0)
slice_start_idx = random.randint(0, max_start_idx)
slice_new = slice_new_big[:, slice_start_idx: slice_start_idx + slice_len]

### Play the slice and show the waveform

In [None]:
# Time axis in seconds: samples are numbered from 1, so the first point sits
# at 1 / sample_rate.
# NOTE(review): this assumes the test track uses dataset_params.sample_rate;
# the rate returned by torchaudio.load is discarded when the track is loaded.
x = np.arange(1, slice_new.shape[-1] + 1) / dataset_params.sample_rate
y = slice_new.squeeze(0).cpu().numpy()
data = pd.DataFrame({"Time [s]": x, "Signal": y})
sns.set_style("darkgrid")
sns.set(rc={'figure.figsize': (20, 5)})
ax = sns.lineplot(data=data, x="Time [s]", y="Signal")
ax.set_ylim((-1, 1))

# Save with the same sample rate that defines the time axis above (it was
# hard-coded to 44100 before), so playback speed matches the plot.
torchaudio.save(
    'sample_out.mp3',
    slice_new.cpu().detach(),
    int(dataset_params.sample_rate),
    format='mp3',
)
IPython.display.Audio(filename="sample_out.mp3")

### Display mel spectrogram

In [None]:
# Convert the slice cut from the test track with the converter built earlier.
slice_mel_spec = mel_spec_converter.convert(slice_new)

# Leading singleton axes are squeezed away before plotting.
ax = sns.heatmap(
    data=slice_mel_spec.squeeze(0).squeeze(0).cpu().numpy(),
    cmap="mako",
    robust=True,
)
# Flip the vertical axis so row 0 is drawn at the bottom of the plot.
ax.invert_yaxis()

### Pass the mel-spectrogram through the decoder

In [None]:
# The vocoder takes a dict with the mel-spectrogram under "mel_spec";
# squeeze(1) drops axis 1 when it is a singleton.
inputs = {"mel_spec": slice_mel_spec.squeeze(1).to(device)}
with torch.no_grad():
    # Call the module instance rather than .forward() so that any registered
    # forward hooks run, as recommended by the torch.nn.Module documentation.
    outputs = vocoder(inputs)
# The reconstructed waveform is returned under the "slice" key.
output_slice = outputs["slice"].squeeze(1)

### Play the reconstructed slice

In [None]:
# Time axis in seconds: samples are numbered from 1, so the first point sits
# at 1 / sample_rate.
x = np.arange(1, output_slice.shape[-1] + 1) / dataset_params.sample_rate
y = output_slice.squeeze(0).cpu().numpy()
data = pd.DataFrame({"Time [s]": x, "Signal": y})
sns.set(rc={'figure.figsize': (20, 5)})
ax = sns.lineplot(data=data, x="Time [s]", y="Signal")
ax.set_ylim((-1, 1))
# Save with the same sample rate that defines the time axis above (it was
# hard-coded to 44100 before), so playback speed matches the plot.
torchaudio.save(
    'sample_out.mp3',
    output_slice.cpu().detach(),
    int(dataset_params.sample_rate),
    format='mp3',
)
IPython.display.Audio(filename="sample_out.mp3")

## Random Slice Generation

### Create random slice

In [None]:
# Gaussian noise of shape (1, 512, 32) is multiplied by a diagonal matrix
# whose entries fall linearly from 1 to 0.01, which scales the last axis;
# axes 1 and 2 are then swapped, giving shape (1, 32, 512).
slice_mel_spec = torch.matmul(
    torch.randn((1, 512, 32)),
    torch.diag(torch.linspace(1, 0.01, 32)),
).transpose(1, 2)
# Zero out every negative entry in place.
slice_mel_spec.masked_fill_(slice_mel_spec < 0, 0)
ax = sns.heatmap(
    data=slice_mel_spec.squeeze(0).squeeze(0).cpu().numpy(),
    cmap="mako",
    robust=True,
)
# Flip the vertical axis so row 0 is drawn at the bottom of the plot.
ax.invert_yaxis()

### Pass the mel-spectrogram through the decoder

In [None]:
# The vocoder takes a dict with the mel-spectrogram under "mel_spec";
# squeeze(1) drops axis 1 when it is a singleton.
inputs = {"mel_spec": slice_mel_spec.squeeze(1).to(device)}
with torch.no_grad():
    # Call the module instance rather than .forward() so that any registered
    # forward hooks run, as recommended by the torch.nn.Module documentation.
    outputs = vocoder(inputs)
# The generated waveform is returned under the "slice" key.
output_slice = outputs["slice"].squeeze(1)

### Display the waveform and play the sound

In [None]:
# Time axis in seconds: samples are numbered from 1, so the first point sits
# at 1 / sample_rate.
x = np.arange(1, output_slice.shape[-1] + 1) / dataset_params.sample_rate
y = output_slice.squeeze(0).cpu().numpy()
data = pd.DataFrame({"Time [s]": x, "Signal": y})
sns.set(rc={'figure.figsize': (20, 5)})
ax = sns.lineplot(data=data, x="Time [s]", y="Signal")
ax.set_ylim((-1, 1))
# Save with the same sample rate that defines the time axis above (it was
# hard-coded to 44100 before), so playback speed matches the plot.
torchaudio.save(
    'sample_out.mp3',
    output_slice.cpu().detach(),
    int(dataset_params.sample_rate),
    format='mp3',
)
IPython.display.Audio(filename="sample_out.mp3")