# Notebook for testing tokenization reconstruction

### Change working directory to the project root

In [1]:
import os

# Resolve the target directory (the parent of the initial working directory)
# only once per kernel session. A bare `os.chdir("..")` is not idempotent:
# every re-run of this cell would climb one directory higher and silently
# break the relative paths used by the cells below.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)

### Imports


In [2]:
import random

import torch
import matplotlib.pyplot as plt
import torchaudio
import IPython

from common import registry
from utils.config import load_cfg_from_hydra
from utils.containers import MelSpecParameters
from models.mel_spec_converters import SimpleMelSpecConverter
from models.base import Tokenizer


  from .autonotebook import tqdm as notebook_tqdm


### Load tokenizer model

In [3]:
# Checkpoint file holding the tokenizer weights
weights_path = "weights/tokenizer_best.ckpt"

# Device selection: CUDA auto-detection is kept below for reference but is
# disabled, so everything runs on the CPU.
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f"Current device is {device}")

# Read both Hydra configurations (tokenizer and mamba)
cfg_tokenizer = load_cfg_from_hydra(config_path="../config", config_name="config")
cfg_mamba = load_cfg_from_hydra(config_path="../config", config_name="mamba")

# Notebook-specific overrides of the loaded configurations
cfg_mamba.dataset.index_series_length = 1024
cfg_tokenizer.learning.batch_size = 128
cfg_tokenizer.learning.val_split = 1.0  # Will it let me do that?

# Look up the lightning module class in the registry, build it from the
# checkpoint, move it to the selected device and switch it to eval mode.
module_cls = registry.get_lightning_module(cfg_tokenizer.model.module_type)
module = module_cls.from_cfg(cfg_tokenizer, weights_path).to(device).eval()  # type: ignore

Current device is cpu


### Load tokenized data loader

In [4]:
dataset = registry.get_dataset("mp3_indices").from_cfg(cfg_mamba)
dataset_len = len(dataset)

# Samples containing either of these two token ids are rejected.
# NOTE(review): presumably these are special (non-codebook) tokens that sit
# right after the regular vocabulary - confirm against the dataset code.
vocabulary_size = cfg_tokenizer.model.vocabulary_size
excluded_tokens = (vocabulary_size, vocabulary_size + 1)

# Upper bound on the rejection sampling so the cell cannot spin forever when
# no sample is free of the excluded tokens.
MAX_SAMPLING_ATTEMPTS = 1000

# Select a random sample that contains none of the excluded tokens
for _ in range(MAX_SAMPLING_ATTEMPTS):
    sample_start = random.randrange(dataset_len)
    dataset_slice = dataset[sample_start]  # type: ignore
    indices = dataset_slice["indices"]

    # Vectorised membership test: works for any tensor shape, unlike the
    # builtin `sum(...)`, which iterates element by element and only yields a
    # scalar for 1-D tensors.
    contains_excluded = (
        (indices == excluded_tokens[0]) | (indices == excluded_tokens[1])
    ).any()
    if not contains_excluded:
        break
else:
    raise RuntimeError(
        f"No sample free of tokens {excluded_tokens} found "
        f"after {MAX_SAMPLING_ATTEMPTS} attempts"
    )

# Reshape the flat index series into rows of 4 indices each
indices_sample = indices.view(-1, 4)

print(f"Indices sample size: {indices_sample.size()}")

Indices sample size: torch.Size([256, 4])


### WAV player element

In [5]:
# This is a wrapper that takes a filename and publishes an HTML <audio> tag to listen to it

def wavPlayer(filepath):
    """Display an HTML5 audio player for the given file in the notebook output.

    Parameters
    ----------
    filepath : str or path-like
        Path of the file to play. It is appended to the ``files/`` URL prefix
        of the ``<source>`` tag, i.e. it is expected to be relative to the
        notebook directory (where the .ipynb lives, not the cwd).

    Notes
    -----
    The browser needs to know how to play the file through HTML5; the source
    is declared with MIME type ``audio/mp3``.

    There is no autoplay, to prevent the file from playing when the browser
    opens.
    """
    # Imported locally so the function does not depend on `HTML` / `display`
    # being injected into the notebook namespace (only `import IPython` is
    # present in the imports cell, so a bare `HTML` raises NameError).
    import html
    from IPython.display import HTML, display

    # Escape the path before embedding it in the `src` attribute so quotes or
    # ampersands in the filename cannot break the markup. `str()` also makes
    # path-like objects (and tuples) format as a single value.
    safe_path = html.escape(str(filepath), quote=True)

    src = """
    <head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <title>Simple Test</title>
    </head>
    
    <body>
    <audio controls="controls" style="width:600px" >
        <source src="files/%s" type="audio/mp3" />
        Your browser does not support the audio element.
    </audio>
    </body>
    """ % (safe_path,)
    display(HTML(src))

### Initialize Mel Spectrogram converter

In [6]:
# Number of mel bands shared by the converter and the weighting matrix below
n_mels = 32

mel_spec_params = MelSpecParameters(
    n_fft=1024,
    f_min=0,
    hop_length=64,
    n_mels=n_mels,
    power=1.0,
    pad=0,
)
mel_spec_converter = SimpleMelSpecConverter(mel_spec_params)

# One weight per mel band. Start and end are both 1.0, so every weight is 1
# and the resulting diagonal matrix is the identity.
lin_vector = torch.linspace(start=1.0, end=1.0, steps=n_mels)
eye_mat = torch.diag(lin_vector).to(device)

## Rebuild Tokenized Slice

In [18]:
# Show the shape of the selected indices before decoding them.
print(indices_sample.size())
# Decode the indices back through the tokenizer model; a trailing singleton
# dimension is appended first.
# NOTE(review): per the traceback recorded below this cell, this call
# currently fails with a RuntimeError (a weight of size [16, 16, 3] expects 16
# channels but receives 4). Presumably `from_tokens` expects a different input
# layout than the one produced by `unsqueeze(-1)` - TODO confirm against the
# model's `from_tokens` implementation.
tokenizer_output = module.model.from_tokens(indices_sample.unsqueeze(-1))

torch.Size([256, 4])


RuntimeError: Given groups=1, weight of size [16, 16, 3], expected input[256, 4, 16] to have 16 channels, but got 4 channels instead