----

# <center><b>Audio ResNet - LatentPlay</b></center>
# <center><b><span style="color:red;">Kick neural synthesis</span></b></center>
----
## Adhémar DE SENNEVILLE | adhemar.de_senneville@ens-paris-saclay.fr
-----

# Initialization

In [1]:
RUN_NAME = 'RUN_11_Deeper_05_Compression' # WandB run name
RUN_TEST = False # Do a fast run to check the pipeline is working (forwarded to the Trainer as `fast_dev_run`)
RUN_1_BATCH = False # Restrict training to one batch per epoch (drives `limit_train_batches`)
MAX_EPOCH = 1000 # Upper bound on the number of training epochs (used as `max_epochs`)

PROJECT = 'Latent_Play' # WandB project name, monitor audio outputs in real time
HARDWARE = 'P100' # T4 or P100 or CPU
DATA_SAVE = 'data' # Directory to save data

In [2]:
# Imports
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import random
import time
import sys
import gc
import yaml
import pickle
from math import floor
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

# Imports Torch
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader, TensorDataset, Dataset
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl

# Imports Autre
import librosa
import matplotlib.pyplot as plt
import scipy.io.wavfile as wav
import scipy.signal as signal
import IPython.display as ipd
from matplotlib import pyplot as plt
import seaborn as sns

sns.set(style="whitegrid")
# Output directories: DATA_SAVE receives the checkpoint and exported parameters, 'fig' the saved figures.
# Using DATA_SAVE (instead of a hard-coded 'data') keeps this in sync with the paths used when saving.
os.makedirs(DATA_SAVE, exist_ok=True)
os.makedirs('fig', exist_ok=True)

### Kaggle setup

In [3]:
KAGGLE = os.getcwd() == "/kaggle/working"

# Avoid pip install att each re-run
try:
    print("Not first session", FIRST_RUN)
except:
    print("First session")
    FIRST_RUN = True

if KAGGLE and FIRST_RUN:
    print("Notebook running on Kaggle")
    PATH = '/kaggle/input/kick-wav/kick_dataset'
    
    #!pip install -U 'wandb>=0.12.10' # NOT WORKING
    #!pip uninstall -y wandb
    #!pip install -U 'wandb==0.17.0' # Only version compatible...
    !pip install auraloss
    !pip install audio-encoders-pytorch
    FIRST_RUN = False
else:
    PATH = '../Dataset/kick_dataset' # Setup the path to you dataset
    print("Notebook running on local")

First session
Notebook running on Kaggle
Collecting auraloss
  Downloading auraloss-0.4.0-py3-none-any.whl.metadata (8.0 kB)
Downloading auraloss-0.4.0-py3-none-any.whl (16 kB)
Installing collected packages: auraloss
Successfully installed auraloss-0.4.0
Collecting audio-encoders-pytorch
  Downloading audio_encoders_pytorch-0.0.22-py3-none-any.whl.metadata (783 bytes)
Collecting data-science-types>=0.2 (from audio-encoders-pytorch)
  Downloading data_science_types-0.2.23-py3-none-any.whl.metadata (5.4 kB)
Collecting einops>=0.6 (from audio-encoders-pytorch)
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting einops-exts>=0.0.3 (from audio-encoders-pytorch)
  Downloading einops_exts-0.0.4-py3-none-any.whl.metadata (621 bytes)
Downloading audio_encoders_pytorch-0.0.22-py3-none-any.whl (9.6 kB)
Downloading data_science_types-0.2.23-py3-none-any.whl (42 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.7/42.7 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00

### WandB setup

In [7]:
# WandB API key: read from the Kaggle secret store on Kaggle, from the environment elsewhere.
if KAGGLE:
    from kaggle_secrets import UserSecretsClient
    API_KEY = UserSecretsClient().get_secret("WandB_API_Key")
    os.environ["WANDB_API_KEY"] = API_KEY
else:
    # `kaggle_secrets` only exists on Kaggle; locally, export WANDB_API_KEY before starting the notebook.
    API_KEY = os.environ.get("WANDB_API_KEY")
    if API_KEY is None:
        print("WANDB_API_KEY is not set: WandB may prompt for a login when the logger starts")

### Seeding

In [None]:
from pytorch_lightning import seed_everything
# Fix the global seed once for the whole notebook.
# NOTE(review): the trainer configuration in this notebook uses deterministic=False and
# benchmark=True, so runs are presumably still not bit-for-bit reproducible - confirm if that matters.
seed_everything(42, workers=True) # sets seeds for numpy, torch and python.random.

----
# Utils
Helper functions; this code does not need to be changed.

In [None]:
# Compute the max fft frequency (temporal 0 padding could be added for more precision)
def get_max_frequency(x, sr):
    """Return the frequency (in Hz) of the strongest spectral bin of a real signal.

    Args:
        x: real-valued tensor whose last axis is time. When it has several rows,
            the strongest bin over all rows is reported.
        sr: sample rate in Hz.

    Returns:
        float: `bin_index * sr / n_samples`, always within [0, sr / 2].
    """
    with torch.no_grad():
        # rfft keeps only the non-negative frequencies. With a full fft the mirrored
        # negative-frequency bin can win the argmax and report a value above sr / 2.
        spectrum = torch.fft.rfft(x)
        power_spectrum = torch.abs(spectrum) ** 2

        # argmax works on the flattened tensor: fold the index back onto the frequency axis
        max_power_index = torch.argmax(power_spectrum) % power_spectrum.size(-1)
        max_frequency = max_power_index * sr / x.size(-1)

        return max_frequency.item()

# Relative position (0..1) of the absolute-amplitude peak
def get_attack(x):
    """Return the index of the largest |x| value divided by the length of the last axis."""
    peak_position = x.abs().argmax()
    total_samples = x.size(-1)
    return (peak_position / total_samples).item()

# Compute the release descriptor from the last sample above the threshold
def get_release(x, threshold = 0.1):
    """Return `0.5 - i / n`, where `i` is the last index above `threshold` in the analysed window.

    Only the first half of the first row of `x` is inspected, so the result lies in (0, 0.5].
    NOTE(review): the original variable was called `x_end` and its comment read
    "min release at 50%", which may mean the second half was intended - confirm.
    The window is left unchanged here so existing feature values stay the same.

    Args:
        x: 2-D tensor (row, time); only row 0 is used.
        threshold: absolute amplitude above which the signal is considered active.

    Returns:
        float: the release value, or 0.5 when no sample of the window exceeds the threshold.
    """
    window = x[0, :x.size(-1)//2]
    indices_above_threshold = torch.where(window.abs() > threshold)[0]

    # Previously an IndexError: nothing above the threshold (silent or late-peaking sample)
    if indices_above_threshold.numel() == 0:
        return 0.5

    release_index = indices_above_threshold[-1]
    return 0.5 - (release_index / x.size(-1)).item()

def clip(value, min_value, max_value):
    """Wrap `value` in a new tensor and clamp it to [min_value, max_value]."""
    wrapped = torch.tensor(value)
    return torch.clamp(wrapped, min=min_value, max=max_value)

## Data

In [None]:
class KickDataset(Dataset):
    """In-memory dataset of fixed-length mono samples plus three scalar descriptors.

    Every `.wav` file found under `path` is indexed at construction time. `load_all`
    then loads the audio into `self.data` (n, 1, sample_length) and fills
    `self.features` (n, 3) with the normalized frequency, attack and release descriptors.

    Args:
        path: root directory searched recursively for `.wav` files.
        sr: target sample rate; files with another rate are resampled.
        duration: clip duration in seconds (shorter files are zero padded, longer ones trimmed).
        fade_out: fraction of the clip (0..1) covered by the final linear fade-out.
    """

    def __init__(self, path = "", sr = 44100, duration = 1, fade_out = 0.1):

        # Init variables
        self.path = path
        self.sr = sr
        self.duration = duration
        self.fade_out = fade_out
        self.sample_length = int(sr * duration)
        self.wav_files = self._create_index_table()

        # Filled by load_all()
        self.data = None
        self.features = None

    def _create_index_table(self):
        """Return the sorted list of `.wav` paths found under `self.path`."""
        wav_files = []
        for root, _, files in os.walk(self.path):
            for file in files:
                if file.endswith('.wav'):
                    wav_files.append(os.path.join(root, file))
        # os.walk order depends on the filesystem: sort to get a stable sample order
        return sorted(wav_files)

    def load_all(self, limit=None):
        """Load the first `limit` files (all of them when `limit` is None) into memory."""
        n_files = len(self.wav_files)
        limit = n_files if limit is None else max(0, min(limit, n_files))
        data = torch.zeros((limit, 1, self.sample_length))
        features = torch.zeros((limit, 3))

        for i, wav_path in enumerate(self.wav_files[:limit]):
            x = self.load_wav(wav_path)
            data[i] = x

            # Normalize and add features
            features[i, 0] = clip(get_max_frequency(x, self.sr),0,150)/150
            features[i, 1] = clip(get_attack(x),0,0.5)/0.5
            features[i, 2] = clip(get_release(x),0,0.5)/0.5

        self.data = data
        self.features = features

    def __len__(self):
        # Once loaded, the length is the number of samples actually in memory
        # (it differs from the number of indexed files when load_all() was given a limit).
        if self.data is not None:
            return self.data.shape[0]
        return len(self.wav_files)

    def __getitem__(self, idx):
        if self.data is None:
            raise RuntimeError("KickDataset.load_all() must be called before indexing the dataset")
        return self.data[idx], self.features[idx]

    def load_wav(self, relative_path):
        """Load one file as a (1, sample_length) tensor: mono, resampled, normalized, faded."""
        # Load audio file
        waveform, sample_rate = torchaudio.load(relative_path, normalize=True)

        # Ensure mono by averaging channels if necessary
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if sample rate is different from cfg
        if sample_rate != self.sr:
            transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)
            waveform = transform(waveform)

        return self._finalize_waveform(waveform)

    def _finalize_waveform(self, waveform):
        """Peak-normalize, pad/trim to `sample_length` and apply the linear fade-out."""
        # Normalize the waveform (skipped for silent or empty audio to avoid dividing by zero)
        peak = waveform.abs().max() if waveform.numel() > 0 else 0
        if peak > 0:
            waveform = waveform / peak

        # Pad or trim to the desired sample length
        if waveform.size(1) < self.sample_length:
            pad_size = self.sample_length - waveform.size(1)
            waveform = torch.nn.functional.pad(waveform, (0, pad_size))
        elif waveform.size(1) > self.sample_length:
            waveform = waveform[:, :self.sample_length]

        # Apply a linear fade-out over the last `fade_out` fraction of the clip
        fade_out_length = min(int(self.sample_length * self.fade_out), self.sample_length)
        if fade_out_length > 0:
            envelope = torch.ones(self.sample_length)
            envelope[-fade_out_length:] = torch.linspace(1, 0, fade_out_length)
            waveform = waveform * envelope

        return waveform

    def show_sample(self, idx):
        """Reload the file at `idx` from disk, print its name and shape and plot the waveform."""
        waveform = self.load_wav(self.wav_files[idx])
        print("Name  :",self.wav_files[idx])
        print("Shape :",waveform.shape)
        plt.figure(figsize=(10, 4))
        plt.plot(waveform.t().numpy())
        plt.title(f"Waveform of sample at index {idx}")
        plt.xlabel("Time")
        plt.ylabel("Amplitude")
        plt.show()



## Loss

In [None]:
import auraloss
import torch.nn as nn

# Audio quality is improved (compared with the plain MSE loss used before)
# by using this multi-resolution time-frequency loss.
# It reduces the high-frequency flickering noticed with the MSE loss (a time-domain loss).

class TimeFrequencyLoss(nn.Module):
    """Multi-resolution STFT magnitude loss plus a weighted time-domain MSE.

    Both terms are computed on signals multiplied by a decaying envelope
    `1 + gain * exp(-t / tau)`, with `t` running from 0 to 1 over the clip, so that
    errors at the start of the sound weigh more than errors in the tail.

    NOTE(review): the original code built the envelope-weighted signals but then fed
    the raw signals to both losses, so `tau` and `gain` had no effect. The weighted
    signals are now used; this changes the loss values compared with earlier runs.

    Args:
        alpha: weight of the time-domain MSE relative to the frequency loss.
        tau: decay constant of the envelope, as a fraction of the clip length.
        gain: extra weight given to the very first sample (envelope starts at 1 + gain).
        sr: sample rate in Hz.
        duration: clip duration in seconds; `int(sr * duration)` must match the signal length.
    """

    def __init__(self, alpha, tau, gain, sr, duration):
        super().__init__()

        # Sample size ~6000: only FFT sizes up to 2048 are used
        # (spectral-convergence and phase terms disabled, linear + log magnitude kept)
        self.frequ_loss = auraloss.freq.MultiResolutionSTFTLoss(
            fft_sizes=[32, 128, 512, 2048], #[32, 128, 512, 2048, 8192, 32768]
            hop_sizes=[16, 64, 256, 1024], #[16, 64, 256, 1024, 4096, 16384]
            win_lengths=[32, 128, 512, 2048], #[32, 128, 512, 2048, 8192, 32768]
            w_sc=0.0,
            w_phs=0.0,
            w_lin_mag=1.0,
            w_log_mag=1.0,
        )
        self.time_loss = nn.MSELoss()
        self.alpha = alpha

        length = int(sr * duration)
        t = torch.linspace(0,1,length)
        self.enveloppe = 1 + gain * torch.exp(-t/tau)

    def forward(self, y_hat, y):
        """Return `freq_loss + alpha * mse`, both evaluated on the envelope-weighted signals."""
        enveloppe = self.enveloppe.to(y_hat.device)

        # Broadcasts over the leading (batch, channel) axes
        y_hat_mod = y_hat * enveloppe
        y_mod = y * enveloppe

        # Calculate frequency domain loss
        f_loss = self.frequ_loss(y_hat_mod, y_mod)

        # Calculate time domain loss
        t_loss = self.time_loss(y_hat_mod, y_mod)

        # Combine the losses
        total_loss = f_loss + self.alpha * t_loss
        return total_loss

## Model

In [None]:
from audio_encoders_pytorch import AutoEncoder1d #from model import AutoEncoder1d

class LitAutoEncoder(pl.LightningModule):
    """Convolutional auto-encoder with a dense bottleneck and a linear feature head.

    `AutoEncoder1d` encodes the waveform; `encode_linear` projects its flattened output
    to a latent vector of `int(length * compression_rate)` values; `decode_linear` maps
    the latent vector back before the convolutional decoder; `features_linear` predicts
    the 3 per-sample descriptors from the latent vector.

    Args:
        model_cfg: dict with 'in_channels', 'channels', 'multipliers', 'factors',
            'num_blocks' and 'compression_rate'.
        training_cfg: dict with 'lr', 'sr', 'duration', 'audio_loss_params' and
            'features_loss_params'.
        data_cfg: dict with 'sr' and 'duration'.
    """

    def __init__(self, model_cfg, training_cfg, data_cfg):
        super().__init__()
        self.save_hyperparameters() # for wandb

        # Dim values
        self.length = int(data_cfg['duration'] * data_cfg['sr'])

        # Length left after the strided stages (ceil division by each factor)
        embedded_length = self.length
        for factor in model_cfg['factors']:
            embedded_length = (embedded_length - 1) // factor + 1

        self.compression_rate = model_cfg['compression_rate']
        # NOTE(review): assumes the encoder outputs `channels` feature maps, which presumably
        # only holds when the last multiplier is 1 - confirm against AutoEncoder1d.
        self.conv_out_channels = model_cfg['channels']
        self.conv_out_length = embedded_length
        self.conv_out_dim = self.conv_out_channels * self.conv_out_length
        self.latent_dim = int(self.length * self.compression_rate)

        # Model initialization
        self.model = AutoEncoder1d(
            in_channels=model_cfg['in_channels'],
            channels=model_cfg['channels'],
            multipliers=model_cfg['multipliers'],
            factors=model_cfg['factors'],
            num_blocks=model_cfg['num_blocks']
        )
        # Dense layers
        self.encode_linear = nn.Linear(self.conv_out_dim, self.latent_dim)
        self.decode_linear = nn.Linear(self.latent_dim, self.conv_out_dim)
        self.features_linear = nn.Linear(self.latent_dim, 3)

        # Model infos
        print(f'{"Global Compression Rate":>30} : {100 * self.compression_rate:.2f} % (~{int(0.5+1/self.compression_rate)})')
        print(f'{"Convolution Compression Rate":>30} : {100 * self.conv_out_dim / self.length:.2f} % (~{int(0.5+self.length/self.conv_out_dim)})')
        print(f'{"Dense Compression Rate":>30} : {100 * self.latent_dim / self.conv_out_dim:.2f} % (~{int(0.5+self.conv_out_dim/self.latent_dim)})')
        # The input channel count is the configured one (the previous value,
        # channels // prod(multipliers), was unrelated to the actual input)
        print(f'{"Input Shape":>30} : ({model_cfg["in_channels"]},{self.length})')
        print(f'{"Conv Latent Shape":>30} : ({self.conv_out_channels},{self.conv_out_length}) -> {self.conv_out_dim}')
        print(f'{"Latent Shape":>30} : ({self.latent_dim})')
        print(f'{"Parameter Number":>30} : ({self.count_parameters()})')
        print(f'{"Encoder + Decoder are":>30} : Resnet-{int(np.sum(model_cfg["num_blocks"])):_}')

        # Training init
        self.lr = training_cfg['lr']
        self.sr = data_cfg['sr']
        self.audio_loss = TimeFrequencyLoss(**training_cfg['audio_loss_params'],
                                         sr = training_cfg['sr'],
                                         duration = training_cfg['duration'])
        self.beta = training_cfg['features_loss_params']['beta']
        self.features_loss = nn.MSELoss()

        # Placeholder for the first batch, first audio samples
        self.first_audio_sample = None

    def forward(self, x):
        """Return `(reconstruction, latent, predicted_features)` for a batch of waveforms."""
        z = self.forward_encode(x)
        f = self.forward_features(z)
        x = self.forward_decode(z)
        return x, z, f

    def forward_features(self, z):
        """Predict the 3 descriptors from the latent vectors."""
        return self.features_linear(z)

    def forward_encode(self, x):
        """Encode waveforms to latent vectors of size `latent_dim`."""
        z = self.model.encode(x)
        z = z.flatten(1)
        z = self.encode_linear(z)
        return z

    def forward_decode(self,z):
        """Decode latent vectors to waveforms cropped to the input length."""
        z = self.decode_linear(z)
        z = z.view(z.shape[0],self.conv_out_channels,self.conv_out_length)
        x = self.model.decode(z)
        # The decoder output can be longer than the input: keep the first `length` samples
        return x[..., :self.length]

    def count_parameters(self):
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def training_step(self, batch, batch_idx):
        x, f = batch

        # Keep one reference sample for the periodic audio logs.
        # Cloned so the whole first batch is not kept alive through a view.
        if self.first_audio_sample is None:
            self.first_audio_sample = x[(0,),...].detach().clone()

        # Compute loss
        x_hat, z, f_hat = self.forward(x)
        audio_loss = self.audio_loss(x_hat, x)
        features_loss = self.beta * self.features_loss(f_hat, f)
        total_loss = audio_loss+features_loss

        # Log
        self.log_dict({'train_loss': total_loss,
                       'audio_loss': audio_loss,
                       'features_loss':features_loss,
                      }) #, on_step=True, on_epoch=True
        return total_loss

    def predict_step(self, x, batch_idx):
        """Return the `forward` tuple; accepts a bare tensor or an `(audio, features)` batch."""
        if isinstance(x, (tuple, list)):
            x = x[0]
        return self.forward(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def on_train_epoch_end(self):
        """Every 10 epochs (1, 11, 21, ...) log the reference sample and its reconstruction."""
        if self.first_audio_sample is None or self.current_epoch % 10 != 1:
            return

        # Only loggers whose experiment exposes `.log` are supported;
        # a disabled logger (None) is skipped instead of raising.
        experiment = getattr(self.logger, 'experiment', None)
        if not hasattr(experiment, 'log'):
            return

        # Local import so the class does not depend on a later notebook cell
        import wandb

        # Reconstruction for monitoring only: no graph needed
        with torch.no_grad():
            reconstruction = self.forward(self.first_audio_sample)[0]

        # Ensure the audio data is float32 numpy, shape (channel, time)
        original_audio = self.first_audio_sample[0].cpu().numpy().astype('float32', copy=False)
        reconstructed_audio = reconstruction[0].cpu().numpy().astype('float32', copy=False)

        first_logging_epoch = self.current_epoch == 1

        # Log Audios (the original only once, it never changes)
        if first_logging_epoch:
            experiment.log({
                "original_audio": wandb.Audio(original_audio[0], sample_rate=self.sr, caption="Original Audio"),
                "epoch": self.current_epoch
            })
        experiment.log({
            "reconstructed_audio": wandb.Audio(reconstructed_audio[0], sample_rate=self.sr, caption="Reconstructed Audio"),
            "epoch": self.current_epoch
        })

        # Log Spectrograms
        if first_logging_epoch:
            self._log_spectrogram(experiment, original_audio[0], 'Original Spectrogram',
                                  "original_spectrogram", 'fig/original_spectrogram.png')
        self._log_spectrogram(experiment, reconstructed_audio[0], 'Reconstructed Spectrogram',
                              "reconstructed_spectrogram", 'fig/reconstructed_spectrogram.png')

    def _log_spectrogram(self, experiment, audio, title, key, path):
        """Save the dB spectrogram of `audio` to `path` and log it under `key`."""
        import wandb
        import librosa.display  # submodule is not guaranteed to be loaded by `import librosa`

        spectrogram_db = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)

        plt.figure(figsize=(5, 5))
        librosa.display.specshow(spectrogram_db, sr=self.sr, x_axis='time', y_axis='log')
        plt.colorbar(format='%+2.0f dB')
        plt.title(title)
        plt.tight_layout()
        plt.savefig(path)
        plt.close()

        experiment.log({
            key: wandb.Image(path),
            "epoch": self.current_epoch
        })



### Testing model inference

In [None]:
# Disabled smoke test of the model on random tensors.
# The whole body is a single string literal, so nothing below runs even if the flag
# is flipped: remove the triple quotes as well to actually run it.
# Note that 'multipliers' lists 7 values for 6 'factors' / 'num_blocks'.
if False:
    '''model_cfg = {
        'in_channels': 1,
        'channels': 64,
        'multipliers': [1, 2, 2, 2, 2, 1, 1],
        'factors':     [4, 4, 4, 4, 4, 2],
        'num_blocks':  [5, 5, 5, 5, 5, 5],
        'in_length': 6615,
        'compression_rate': 0.01,
    }

    training_cfg = {
        # Learning
        'epoch': 400,
        'epoch_min': 50,
        'patience': 50,
        'lr': 1e-4,

        # Loss
        'audio_loss_params':{
            'alpha': 100,
            'tau': 0.1,
            'gain': 5,
        },
        'features_loss_params':{
            'beta':50,
        },

        # Data
        'sr': 22050,
        'duration': 0.3,

        # Other
        'batch_size': 32,
        'num_workers': 3,
        'best_model_path': None,
        'hardware': HARDWARE,
        'machine': 'Kaggle' if KAGGLE else 'PC'
    }
    
    data_cfg = {
        'path': PATH,
        'sr': 22050,
        'duration': 0.3, # seconds
        'fade_out': 0.1 # 10% fade out
    }

    # Initialize the model
    model = LitAutoEncoder(model_cfg, training_cfg, data_cfg)

    # Generate random input
    input_tensor_1 = torch.randn(32, 1, 6615)
    input_tensor_2 = torch.randn(32, 3)

    # Quick inference encoder
    laten = model.model.encoder(input_tensor_1)
    print(f'{laten.shape = }')
    
    # Quick inference auto-encoder
    output = model(input_tensor_1)#, input_tensor_2
    print(f'{output[0].shape = }')'''

----
# Training Configuration

In [None]:
# Training hyper-parameters (also sent to WandB and exported with the plugin config)
training_cfg = {
    # Learning
    'epoch': MAX_EPOCH,   # maximum number of epochs
    'epoch_min': 100,     # minimum number of epochs (Trainer `min_epochs`)
    'patience': 300,      # early-stopping patience on 'train_loss'
    'lr': 3e-5,           # Adam learning rate
    
    # Loss
    'audio_loss_params':{
        'alpha': 100, # weight of the time-domain MSE term
        'tau': 0.1, # See README
        'gain': 5,  # See README
    },
    'features_loss_params':{
        'beta':50, # weight of the feature-regression MSE term
    },
    
    # Data
    'sr': 22050, # sample rate, in samples per second
    'duration': 0.3, # in seconds
    
    # Other
    'batch_size': 32,
    'num_workers': 3, # DataLoader worker processes
    'best_model_path': None,
    'hardware': HARDWARE,
    'machine': 'Kaggle' if KAGGLE else 'PC'
}


## Data Config

In [None]:
# Dataset configuration: sample rate and duration are taken from training_cfg so that
# the loss envelope and the dataset agree on the signal length.
data_cfg = {
    'path': PATH, # Path to dataset
    'sr': training_cfg['sr'],
    'duration': training_cfg['duration'], 
    'fade_out': 0.1 # 10% fade out
}

# Create dataset instance (only indexes the files; audio is loaded by load_all())
dataset = KickDataset(**data_cfg)

In [None]:
# Load every sample into memory and report the footprint of the audio tensor
dataset.load_all()
memory_usage_mb = sys.getsizeof(dataset.data.untyped_storage()) / (1024 ** 2)
print(f'Dataset Memory usage: {memory_usage_mb:.2f} MB')
print(f'Dataset Shape       : {dataset.data.shape}')
# Number of samples per clip, recorded in the training config
training_cfg['time_length'] = dataset.data.shape[-1]

In [None]:
# Warning: no validation set and no test set... the objective is to overfit the dataset
dataloader = DataLoader(dataset, 
                        batch_size=training_cfg['batch_size'], 
                        shuffle=True, 
                        num_workers=training_cfg['num_workers'])

## Model Config

In [None]:
# Configuration dictionaries
model_cfg = {
        'in_channels': 1,                          # mono audio
        'channels': 64,                            # base channel count
        'multipliers': [1, 2, 2, 2, 2, 1, 1],      # 7 values for the 6 stages below
        'factors':     [4, 4, 4, 4, 4, 2],         # per-stage length reduction
        'num_blocks':  [6, 6, 6, 6, 6, 6],         # blocks per stage
        'compression_rate': 0.005, # 0.5% Compression
    }
# Initialize model (prints the model summary; a fresh instance is created again before training)
autoencoder = LitAutoEncoder(model_cfg, training_cfg, data_cfg)

## Training Configs

In [None]:
# Kaggle hardware table: name -> (GPU count, device ids, message)
_KAGGLE_HARDWARE = {
    'T4': (2, [0,1], 'On Kaggle Double T4'),
    'P100': (1, [0], 'On Kaggle P100'),
}

if KAGGLE:
    # Any other value falls back to CPU
    num_gpus, available_devices, _hardware_message = _KAGGLE_HARDWARE.get(
        HARDWARE, (0, ['CPU'], 'On Kaggle CPU'))
else:
    # No gpu on my computer :(
    num_gpus, available_devices, _hardware_message = 0, ['CPU'], 'On laptop CPU'
print(_hardware_message)

# Lightning settings: at least one device, on GPU whenever one is available
accelerator = 'cpu' if num_gpus <= 0 else 'gpu'
devices = 1 if num_gpus <= 0 else num_gpus

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Stop when the training loss has not improved for `patience` checks
# (there is no validation set, so the training loss is monitored)
early_stop_callback = EarlyStopping(
    monitor='train_loss',  
    patience=training_cfg['patience'],          
    verbose=True,        
    mode='min'           
)

# Keep only the best checkpoint (lowest training loss) under DATA_SAVE
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',  # Change this to 'val_loss' if you have validation set
    dirpath=f'{DATA_SAVE}/',
    filename='best-checkpoint',
    save_top_k=1,
    mode='min'
)

# Keyword arguments passed to the Trainer
trainer_cfg = {
    
    # Hardware
    'accelerator': accelerator,
    'devices': devices,
    'num_nodes': 1,
    'precision': 32, # Float32
    'deterministic': False, # Increase training speed for our fixed tensor size dataset
    'benchmark': True, # Increase training speed for our fixed tensor size dataset
    
    # Epochs
    'min_epochs': training_cfg['epoch_min'],
    'max_epochs': training_cfg['epoch'],
    'max_time': '00:12:00:00', # DD:HH:MM:SS, i.e. 12 hours
    'accumulate_grad_batches': 1,
    'callbacks': [early_stop_callback, checkpoint_callback],
    
    # Logging / Debug
    'logger': None,  # Defined later
    'profiler': None,   # Defined later
    'fast_dev_run': RUN_TEST,
    'limit_train_batches': 1 if RUN_1_BATCH else None, # a single batch per epoch when RUN_1_BATCH is set
    'enable_checkpointing': True,
    'barebones': False,
}

## Logger Config

In [None]:
import wandb
from pytorch_lightning.loggers import WandbLogger
# Disabled cell kept for debugging the WandB installation.
# NOTE(review): it uses `all_config`, which is only defined in a later cell, so enabling
# it as-is on a fresh session would presumably raise a NameError.
if False: # Debug WandB
    
    #os.environ["WANDB_HTTP_TIMEOUT"] = "300"  # Useless
    #os.environ["WANDB_MODE"] = "offline"      # Not working
    import wandb
    wandb.login()
    wandb.init(project=PROJECT, config=all_config)
    wandb.finish()
    
    # Check which wandb version Lightning detects
    from lightning_utilities.core.imports import RequirementCache
    _WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10")
    _WANDB_AVAILABLE
    
    !ls /opt/conda/lib/python3.10/site-packages/wandb-0.17.0.dist-info
    #!sudo /opt/conda/lib/python3.10/site-packages/wandb-0.17.0.dist-info/METADATA

In [None]:
# Everything sent to WandB as the run configuration
all_config = {
    'TRAINING': training_cfg,
    'MODEL': model_cfg,
    'DATA': data_cfg,
    'TRAINER': trainer_cfg,
}

profiler = 'simple' # Track training performance
# No WandB run is created for the quick pipeline checks (RUN_TEST / RUN_1_BATCH)
logger = None if (RUN_TEST or RUN_1_BATCH) else WandbLogger(project=PROJECT, name=RUN_NAME, config=all_config)

# Fill in the two entries left as None when trainer_cfg was built
trainer_cfg['profiler'] = profiler
trainer_cfg['logger'] = logger

----
# Run

In [None]:
# Initialize model (replaces the instance created in the model config cell with fresh weights)
autoencoder = LitAutoEncoder(model_cfg, training_cfg, data_cfg)

# Initialize trainer
trainer = Trainer(**trainer_cfg)

In [None]:
# Here we go: train on the full dataset (no validation loader), then close the WandB run
trainer.fit(autoencoder, dataloader)
wandb.finish()

In [None]:
# Run this cell to close the WandB run by hand, e.g. after interrupting training
wandb.finish()

In [None]:
# Disabled by default: blocks "Run all" here until a non-empty line is entered
if False:    # If you want to pause execution after training
    print("Paused execution. Run the following line to continue.")
    while not input():
        time.sleep(1)

----
# Play with latent

In [None]:
# All parameters used for the generation of the vst
# 'audio_length' is derived from the data configuration (same formula as the dataset's
# sample_length) instead of being hard-coded, so it follows any change of sr / duration.
latent_play_parameters = {'sr': data_cfg['sr'],
                         'audio_length': int(data_cfg['sr'] * data_cfg['duration']),
                         }

In [None]:
# Reload the best weights saved during training.
# The path reported by the checkpoint callback is preferred over a fixed file name:
# an older 'best-checkpoint.ckpt' left in the folder would otherwise be loaded silently.
# On failure the in-memory `autoencoder` is kept (best effort), but the reason is shown.
try:
    checkpoint_path = checkpoint_callback.best_model_path or f'{DATA_SAVE}/best-checkpoint.ckpt'
    print(checkpoint_path)
    autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path)
except Exception as error:
    print('Impossible to load checkpoint')
    print(f'Reason: {error!r}')

In [None]:
# Encode every sample of the dataset into its latent vector
device = autoencoder.device
total = len(dataset)
latent_vectors = []

with torch.no_grad():
    for i, (x, _) in enumerate(dataset):
        # Progress display, refreshed every 10 samples
        if i % 10 == 0:
            print(round(100 * i / total, 2), '%', end='\r')

        # Add the batch axis, encode, and bring the result back as a numpy vector
        x_in = x.unsqueeze(0).to(device)
        z_sample = autoencoder.forward_encode(x_in)
        latent_vectors.append(z_sample.squeeze(0).cpu().detach().numpy())

# One row per sample
Z = np.array(latent_vectors)
print(f"Encoded Data Shape: {Z.shape}")

## PCA

In [None]:
from sklearn.decomposition import PCA

# Latent PCA
pca = PCA(n_components=2)
principal_components = pca.fit_transform(Z)
z_pca_1, z_pca_2 = pca.components_

print(f'{principal_components.shape = }')
print(f'{z_pca_1.shape = }')

# The dataset is composed of different sample packs
# Each sample pack can be a class: it is the first folder below the dataset root.
# Derived from the dataset path (no hard-coded folder name or separator) and limited
# to the samples that were actually encoded.
y_classes = [
    os.path.relpath(path, dataset.path).split(os.sep)[0]
    for path in dataset.wav_files[:len(Z)]
]

In [None]:
# Export the two principal directions of the latent space
latent_play_parameters["z_pca_1"] = z_pca_1
latent_play_parameters["z_pca_2"] = z_pca_2

# ... and the (min, max) range covered by the dataset along each of them
first_axis = principal_components[:,0]
second_axis = principal_components[:,1]
latent_play_parameters["z_pca_1_scale"] = (np.min(first_axis), np.max(first_axis))
latent_play_parameters["z_pca_2_scale"] = (np.min(second_axis), np.max(second_axis))

### Clustering

In [None]:
# Nice cluster plot
import matplotlib.colors as mcolors
colors = list(mcolors.CSS4_COLORS.values())

# Create a scatter plot
plt.figure(figsize=(16, 14))
# Sorted so that each class keeps the same color from one run to the next
unique_classes = sorted(set(y_classes))
class_labels = np.asarray(y_classes)

for i, class_name in enumerate(unique_classes):
    class_indices = np.flatnonzero(class_labels == class_name)

    # Plot points for each class (colors cycle if there are more classes than colors)
    plt.scatter(principal_components[class_indices, 0],
                principal_components[class_indices, 1],
                s=10, color=colors[i % len(colors)], label=class_name, alpha=1)

plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
# No centroid is drawn on this figure
plt.title('2D Scatter Plot of Principal Components')
plt.grid(False)
plt.legend()

plt.tight_layout()
plt.savefig('fig/pca_2D_space.png')
plt.show()

In [None]:
# Nice map of packs in latent space
# Create a scatter plot
plt.figure(figsize=(16, 14))
# Sorted so that each class keeps the same color from one run to the next
unique_classes = sorted(set(y_classes))
class_labels = np.asarray(y_classes)
centroids = []

for i, class_name in enumerate(unique_classes):
    class_indices = np.flatnonzero(class_labels == class_name)
    class_color = colors[i % len(colors)]  # cycle if there are more classes than colors

    # Plot points for each class
    plt.scatter(principal_components[class_indices, 0],
                principal_components[class_indices, 1],
                s=30, color=class_color, label=class_name, alpha=1)

    # Compute and plot centroids
    centroid = np.mean(principal_components[class_indices], axis=0)
    centroids.append(centroid)
    plt.scatter(centroid[0], centroid[1], s=100, color=class_color, edgecolors='black', marker='X')
    plt.text(centroid[0], centroid[1], class_name, fontsize=12, ha='center', va='center',
             color='black', backgroundcolor=class_color)

plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')

# Zoom on the centroids with a 20% margin. The margin is added away from the data on
# both sides; multiplying by 1.2 only does that when min < 0 < max.
centroid_array = np.array(centroids)
x_low, x_high = centroid_array[:,0].min(), centroid_array[:,0].max()
y_low, y_high = centroid_array[:,1].min(), centroid_array[:,1].max()
plt.xlim((x_low - 0.2 * abs(x_low), x_high + 0.2 * abs(x_high)))
plt.ylim((y_low - 0.2 * abs(y_low), y_high + 0.2 * abs(y_high)))
plt.title('2D Scatter Plot of Principal Components with Centroids')
plt.grid(False)

plt.tight_layout()
plt.savefig('fig/pca_cluster.png')
plt.show()

## Linking *High level features* and *Latent Space*

### Frequency feature

In [None]:
# Bar plot of the distribution of the frequency feature over the dataset
Y_freq = dataset.features[:,0]
unique_freqs, counts = np.unique(Y_freq, return_counts=True)

plt.figure(figsize=(15, 6))  # Make the plot wider
bars = plt.bar(unique_freqs, counts, color='black', edgecolor='black', width=0.02)
# The feature is stored normalized (clipped to 150 Hz then divided by 150), not in Hz
plt.xlabel('Frequency (normalized, 1.0 = 150 Hz)')
plt.ylabel('Number of Predictions')
plt.title('Frequency Distribution of Predictions')

# Add grid lines
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.xlim(0, 1)

plt.tight_layout()
plt.savefig('fig/feature_freq.png')
plt.show()

### Attack feature

In [None]:
# Distribution of the attack feature: full range on the left, zoom near zero on the right
Y_attack = dataset.features[:,1]
unique_freqs, counts = np.unique(Y_attack, return_counts=True)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 6))  # Create subplots

# (axes, bar width, title, x range) for the main and the zoomed-in plot
attack_panels = (
    (ax1, 0.001, 'Attack Distribution of Predictions', (0, 1)),
    (ax2, 0.0002, 'Zoomed-in Attack Distribution', (0, 0.1)),
)
for panel, bar_width, panel_title, x_range in attack_panels:
    panel.bar(unique_freqs, counts, color='black', edgecolor='black', width=bar_width)
    panel.set_xlabel('Attack (sample)')
    panel.set_ylabel('Number of Predictions')
    panel.set_title(panel_title)
    panel.grid(axis='y', linestyle='--', alpha=0.7)
    panel.set_xlim(*x_range)

plt.tight_layout()
plt.savefig('fig/feature_attack.png')
plt.show()

### Release feature

In [None]:
# Bar plot of the distribution of the release feature over the dataset
Y_release = dataset.features[:,2]
unique_freqs, counts = np.unique(Y_release, return_counts=True)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 6))  # Create subplots

# Main plot (titles previously said "Frequency", copied from the frequency cell)
ax1.bar(unique_freqs, counts, color='black', edgecolor='black', width=0.001)
ax1.set_xlabel('Release (sample)')
ax1.set_ylabel('Number of Predictions')
ax1.set_title('Release Distribution of Predictions')
ax1.grid(axis='y', linestyle='--', alpha=0.7)
ax1.set_xlim(0, 0.2)

# Zoomed-in plot
ax2.bar(unique_freqs, counts, color='black', edgecolor='black', width=0.0003)
ax2.set_xlabel('Release (sample)')
ax2.set_ylabel('Number of Predictions')
ax2.set_title('Zoomed-in Release Distribution')
ax2.grid(axis='y', linestyle='--', alpha=0.7)
ax2.set_xlim(0, 0.02)
ax2.set_ylim(0, 20)

plt.tight_layout()
plt.savefig('fig/feature_release.png')
plt.show()

### Post-processing based on the graphs

In [None]:
# Convert the feature columns to numpy arrays
Y_freq_ = np.array(Y_freq)
Y_attack_ = np.array(Y_attack)
Y_release_ = np.array(Y_release)

# The features are already clipped and scaled when the dataset is loaded,
# so the steps below stay disabled.
# Avoid letting outliers over-influence the MSE fit
#Y_freq_ = np.clip(np.array(Y_freq_), 0, 150)
#Y_attack_ = np.clip(np.array(Y_attack_), 0, 0.5)
#Y_release_ = np.clip(np.array(Y_release_), 0, 0.5)

# Normalizing
#Y_freq_ = Y_freq_ / np.max(Y_freq_)
#Y_attack_ = Y_attack_ / np.max(Y_attack_)
#Y_release_ = Y_release_ / np.max(Y_release_)

### Linear regression

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error

# Reshape each target to a column vector (n_samples, 1)
Y_freq_ = Y_freq_.reshape(-1, 1)
Y_attack_ = Y_attack_.reshape(-1, 1)
Y_release_ = Y_release_.reshape(-1, 1)

# Perform linear regression for each parameter (latent vector -> feature)
regressor_freq = LinearRegression().fit(Z, Y_freq_)
regressor_attack = LinearRegression().fit(Z, Y_attack_)
regressor_release = LinearRegression().fit(Z, Y_release_)

# Make predictions (on the training data itself: there is no held-out set)
Y_freq_pred = regressor_freq.predict(Z)
Y_attack_pred = regressor_attack.predict(Z)
Y_release_pred = regressor_release.predict(Z)

# Get regression vectors (coefficients only; the fitted intercepts are not kept)
reg_vector_freq = regressor_freq.coef_
reg_vector_attack = regressor_attack.coef_
reg_vector_release = regressor_release.coef_

In [None]:
# Calculate MSE for each parameter
mse_freq = mean_squared_error(Y_freq_, Y_freq_pred)
mse_attack = mean_squared_error(Y_attack_, Y_attack_pred)
mse_release = mean_squared_error(Y_release_, Y_release_pred)

# Calculate MAPE for each parameter
# NOTE(review): the features are clipped at 0, so some targets are presumably exactly 0,
# which can make the MAPE values very large - read them with care.
mape_freq = mean_absolute_percentage_error(Y_freq_, Y_freq_pred)
mape_attack = mean_absolute_percentage_error(Y_attack_, Y_attack_pred)
mape_release = mean_absolute_percentage_error(Y_release_, Y_release_pred)

# Print the results
print(f"Frequency - MSE: {mse_freq}, MAPE: {mape_freq}")
print(f"Attack - MSE: {mse_attack}, MAPE: {mape_attack}")
print(f"Release - MSE: {mse_release}, MAPE: {mape_release}")

In [None]:
# Store the first row of each coefficient matrix as the direction of that feature
for theta_key, coefficients in (
    ("theta_freq", reg_vector_freq),
    ("theta_attack", reg_vector_attack),
    ("theta_release", reg_vector_release),
):
    latent_play_parameters[theta_key] = coefficients[0]
print(reg_vector_freq[0].shape, reg_vector_attack[0].shape, reg_vector_release[0].shape)

### Save data for the Plugin

In [None]:
#!cat data/config.yaml

In [None]:
# Save the dictionary of arrays to a file using pickle
with open(f'{DATA_SAVE}/latent_play_parameters.pkl', 'wb') as file:
    pickle.dump(latent_play_parameters, file)

# Configuration needed by the plugin to rebuild the model
vst_config = {
    'TRAINING': training_cfg,
    'MODEL': model_cfg,
    'DATA': data_cfg,
}

with open(f'{DATA_SAVE}/config.yaml', 'w') as file:
    yaml.dump(vst_config, file)

# Check the created files (portable, and follows DATA_SAVE instead of a fixed 'data' folder)
print(sorted(os.listdir(DATA_SAVE)))

## Control Latent Space

In [None]:
z_pca_1 = latent_play_parameters["z_pca_1"]
z_pca_2 = latent_play_parameters["z_pca_2"]
theta_freq = latent_play_parameters["theta_freq"]
theta_attack = latent_play_parameters["theta_attack"]
theta_release = latent_play_parameters["theta_release"]

# One row per controlled quantity: A @ z gives the five controlled values of a latent z
A = np.stack([z_pca_1, z_pca_2, theta_freq, theta_attack, theta_release])
# Right pseudo-inverse of A. pinv equals A.T @ inv(A @ A.T) when the rows are linearly
# independent, and stays well defined when they are (nearly) collinear.
C = np.linalg.pinv(A)

def control_latent(z,
                   latent_pca1,
                   latent_pca2,
                   target_freq,
                   target_attack,
                   target_release,
                   constraints=None,
                   correction=None,
                   ):
    """Move `z` as little as possible so that `constraints @ z` equals the five targets.

    Args:
        z: latent vector, shape (latent_dim,).
        latent_pca1, latent_pca2: target values along the two PCA directions.
        target_freq, target_attack, target_release: target values along the regression directions.
        constraints: optional (5, latent_dim) matrix; defaults to the module-level `A`.
        correction: optional (latent_dim, 5) pseudo-inverse of `constraints`; defaults to the
            module-level `C`, or is computed with `np.linalg.pinv` when only `constraints` is given.

    Returns:
        The corrected latent vector `z - correction @ (constraints @ z - targets)`.

    NOTE(review): the targets are compared with raw dot products; the PCA mean and the
    regression intercepts are not taken into account here - confirm this is intended.
    """
    # Reading module-level names needs no `global` statement (the previous one also named an undefined `C2`)
    if constraints is None:
        constraints = A
        if correction is None:
            correction = C
    elif correction is None:
        correction = np.linalg.pinv(constraints)

    # Make z satisfy the targets with minimum change according to the linear regression
    b = np.array([latent_pca1, latent_pca2, target_freq, target_attack, target_release]) # !!
    z_prim = z - correction @ (constraints @ z - b)

    return z_prim

# Et voilà !