----

# <center><b>Audio ResNet - VAE</b></center>
# <center><b><span style="color:red;">Kick neural synthesis</span></b></center>
----
## Adhémar DE SENNEVILLE | adhemar.de_senneville@ens-paris-saclay.fr
-----

# Import and Utils

In [1]:
from math import floor
import numpy as np
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import os
import matplotlib.pyplot as plt

PATH = '../Dataset/kick_dataset'

# Data

In [None]:
from torch.utils.data import Dataset
import torchaudio
import torch

class KickDataset(Dataset):
    """Fixed-length, peak-normalised mono waveforms read from ``.wav`` files.

    Every ``.wav`` file found recursively under ``path`` is one item. Items
    are tensors of shape ``(1, sample_length)`` with
    ``sample_length = int(sr * duration)``.

    Args:
        path: Root directory that is searched recursively for ``.wav`` files.
        sr: Target sample rate; files with another rate are resampled.
        duration: Length, in seconds, every waveform is padded/trimmed to.
    """

    def __init__(self, path = "", sr = 44100, duration = 1):
        self.path = path
        self.sr = sr
        self.duration = duration
        self.sample_length = int(sr * duration)
        self.wav_files = self._create_index_table()
        # In-memory cache filled by load_all(); None until then.
        self.data = None

    def _create_index_table(self):
        """Return the sorted paths of all ``.wav`` files below ``self.path``."""
        wav_files = []
        for root, _, files in os.walk(self.path):
            wav_files.extend(
                os.path.join(root, file) for file in files if file.endswith('.wav')
            )
        # os.walk order depends on the file system; sort so that an index
        # always refers to the same file.
        wav_files.sort()
        return wav_files

    def load_all(self, limit=None):
        """Decode the first ``limit`` files (all when ``None``) into ``self.data``."""
        limit = len(self) if limit is None else max(0, min(limit, len(self)))
        data = torch.zeros((limit, 1, self.sample_length))

        for i, wav_path in enumerate(self.wav_files[:limit]):
            data[i] = self.load_wav(wav_path)

        self.data = data

    def __len__(self):
        """Return the number of indexed ``.wav`` files."""
        return len(self.wav_files)

    def __getitem__(self, idx):
        """Return item ``idx``, from the cache when present, else from disk.

        Integer indices cover the whole range ``[-len(self), len(self))`` so
        that the dataset stays consistent with ``__len__`` even when
        ``load_all`` was skipped or called with a ``limit``. Any other index
        type (slice, tensor, ...) is forwarded to the cached tensor.

        Raises:
            IndexError: If an integer index is out of range.
            TypeError: If a non-integer index is used before ``load_all``.
        """
        cached = self.data

        if isinstance(idx, int):
            if idx < 0:
                idx += len(self)
            if not 0 <= idx < len(self):
                raise IndexError(f"index {idx} out of range for {len(self)} files")
            if cached is not None and idx < len(cached):
                return cached[idx]
            return self.load_wav(self.wav_files[idx])

        if cached is None:
            raise TypeError("non-integer indexing requires load_all() to be called first")
        return cached[idx]

    def load_wav(self, relative_path):
        """Load one file as a mono ``(1, sample_length)`` waveform at ``self.sr``."""
        # Load audio file
        waveform, sample_rate = torchaudio.load(relative_path, normalize=True)

        # Ensure mono by averaging channels if necessary
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if sample rate is different from cfg
        if sample_rate != self.sr:
            transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)
            waveform = transform(waveform)

        # Peak-normalise; a silent (or empty) file is left untouched instead
        # of being turned into NaNs by a division by zero.
        if waveform.numel() > 0:
            peak = waveform.abs().max()
            if peak > 0:
                waveform = waveform / peak

        # Pad or trim to the desired sample length
        if waveform.size(1) < self.sample_length:
            pad_size = self.sample_length - waveform.size(1)
            waveform = torch.nn.functional.pad(waveform, (0, pad_size))
        elif waveform.size(1) > self.sample_length:
            waveform = waveform[:, :self.sample_length]

        return waveform

    def show_sample(self, idx):
        """Print the path and shape of item ``idx`` and plot its waveform."""
        waveform = self.load_wav(self.wav_files[idx])
        print("Name  :",self.wav_files[idx])
        print("Shape :",waveform.shape)
        plt.figure(figsize=(10, 4))
        plt.plot(waveform.t().numpy())
        plt.title(f"Waveform of sample at index {idx}")
        plt.xlabel("Time")
        plt.ylabel("Amplitude")
        plt.show()

# Audio settings shared by every sample of the dataset.
data_cfg = {
    'path': PATH,
    'sr': 22050,
    'duration': 0.3, # seconds
}

# Create dataset instance
dataset = KickDataset(**data_cfg)

# Decode every file once: indexing the dataset reads the tensor that
# load_all() stores, which does not exist until this call has run.
dataset.load_all()

# Model

In [5]:
import pytorch_lightning as pl
from model import AutoEncoder1d
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class LitAutoEncoder(pl.LightningModule):
    """Lightning wrapper that trains an ``AutoEncoder1d`` on reconstruction.

    Args:
        model_cfg: Mapping with the keys ``in_channels``, ``channels``,
            ``multipliers``, ``factors`` and ``num_blocks``; they are passed
            to ``AutoEncoder1d`` unchanged.
        training_cfg: Optional mapping. ``lr`` (default ``1e-3``) is the Adam
            learning rate and ``loss`` (default ``F.mse_loss``) is a callable
            ``loss(x_hat, x)``. Other keys are ignored here.
    """

    def __init__(self, model_cfg, training_cfg=None):
        super().__init__()

        # Both settings have defaults, so the module can be built from the
        # model configuration alone.
        training_cfg = {} if training_cfg is None else training_cfg

        # Model initialization
        self.model = AutoEncoder1d(
            in_channels=model_cfg['in_channels'],
            channels=model_cfg['channels'],
            multipliers=model_cfg['multipliers'],
            factors=model_cfg['factors'],
            num_blocks=model_cfg['num_blocks']
        )

        # Training
        self.lr = training_cfg.get('lr', 1e-3)
        self.loss = training_cfg.get('loss', F.mse_loss)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def training_step(self, x, batch_idx):
        """Reconstruct the batch ``x`` and return the logged training loss."""
        x_hat = self.forward(x)
        loss = self.loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)



In [None]:
# Configuration dictionaries
# NOTE(review): with sr=22050 and duration=0.3 each sample has
# int(22050 * 0.3) samples, which is not a multiple of prod(factors) = 64 —
# confirm the decoder output length matches the input before training.
model_cfg = {
    # KickDataset.load_wav averages every file down to one channel, so the
    # network must accept mono input.
    'in_channels': 1,
    'channels': 32,
    'multipliers': [1, 1, 2, 2],
    'factors': [4, 4, 4],
    'num_blocks': [2, 2, 2]
}


# Training

In [None]:
# Optimisation settings; 'lr' and 'loss' are read by LitAutoEncoder, the
# other keys by the DataLoader and the Trainer below.
training_cfg = {
    'batch_size': 32,
    'max_epochs': 10,
    'lr': 1e-3,
    'loss': F.mse_loss,  # reconstruction loss, called as loss(x_hat, x)
}


# Initialize model
autoencoder = LitAutoEncoder(model_cfg, training_cfg)

dataloader = DataLoader(dataset, batch_size=training_cfg['batch_size'], shuffle=True)

# Initialize trainer
# `accelerator`/`devices` replace the `gpus` argument, which recent Lightning
# releases no longer accept.
trainer = pl.Trainer(
    max_epochs=training_cfg['max_epochs'],
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
)

# Train model
trainer.fit(autoencoder, dataloader)

----
# Latent space analysis

In [None]:
# Encode every sample into the latent space, one row per sample.
autoencoder.eval()
with torch.no_grad():
    # Without no_grad the encoder output requires grad and cannot be
    # converted to a NumPy array.
    latents = torch.cat([
        autoencoder.model.encoder(x.unsqueeze(0).to(autoencoder.device))
        for x in dataset
    ])
# Flatten each latent to a vector: scikit-learn estimators expect a 2-D
# (n_samples, n_features) matrix.
Z = latents.flatten(start_dim=1).cpu().numpy()

## PCA

## Link High level features and latent space

### Get Features 
Here we compute a few simple high-level features from each waveform.

In [None]:
def get_max_frequency(x, sr):
    """Return the frequency (Hz) of the strongest spectral component of ``x``.

    Args:
        x: Real-valued waveform tensor whose last dimension is time. Leading
            dimensions (e.g. a channel axis) are merged by summing their
            power spectra.
        sr: Sample rate of ``x`` in Hz.

    Returns:
        The dominant frequency as a Python float, always within ``[0, sr / 2]``.
    """
    with torch.no_grad():
        n_samples = x.size(-1)

        # One-sided FFT: the spectrum of a real signal is symmetric, so the
        # full FFT could return the mirrored bin and report a frequency above
        # the Nyquist limit.
        spectrum = torch.fft.rfft(x, dim=-1)
        power_spectrum = torch.abs(spectrum) ** 2

        # Merge leading dimensions so the argmax is a genuine frequency bin.
        power_spectrum = power_spectrum.reshape(-1, power_spectrum.size(-1)).sum(dim=0)

        # Get the frequency bin with the highest power
        max_power_index = int(torch.argmax(power_spectrum))
        return max_power_index * sr / n_samples

def get_attack(x):
    """Return the position of the absolute peak of ``x`` as a percentage.

    The flat index of the largest absolute value is divided by the size of
    the last dimension and multiplied by 100.
    """
    peak_position = x.abs().argmax()
    # Multiplied by 100 because the raw ratio would be a very small number.
    scaled_position = 100 * peak_position / x.size(-1)
    return scaled_position.item()

def get_release(x, threshold=0.1):
    """Return the relative position of the last sample louder than ``threshold``.

    Args:
        x: Waveform tensor whose last dimension is time.
        threshold: Absolute amplitude a sample must exceed to count.

    Returns:
        ``last_index / x.size(-1)`` as a float, or ``0.0`` when no sample
        exceeds the threshold.
    """
    # torch.where returns one index tensor per dimension; the LAST one holds
    # the time indices. Taking the first one on a (1, T) input would yield
    # channel indices, i.e. always zero.
    time_indices = torch.where(x.abs() > threshold)[-1]
    if time_indices.numel() == 0:
        return 0.0  # No sample above the threshold

    release_index = time_indices.max()
    return (release_index / x.size(-1)).item()

In [None]:
# One scalar feature per dataset item, collected into three 1-D arrays.
sample_rate = dataset.sr
Y_freq = np.array([get_max_frequency(wave, sample_rate) for wave in dataset])
Y_attack = np.array([get_attack(wave) for wave in dataset])
Y_release = np.array([get_release(wave) for wave in dataset])

### Linear regression

In [None]:
from sklearn.linear_model import LinearRegression

# Turn each 1-D target array into an (n_samples, 1) column
Y_freq = Y_freq.reshape(-1, 1)
Y_attack = Y_attack.reshape(-1, 1)
Y_release = Y_release.reshape(-1, 1)

# Perform linear regression for each parameter
regressor_freq = LinearRegression().fit(Z, Y_freq)
regressor_attack = LinearRegression().fit(Z, Y_attack)
regressor_release = LinearRegression().fit(Z, Y_release)

# Get regression vectors
# With column targets `coef_` has shape (1, n_features); flatten it so each
# regression direction is a plain vector of length n_features.
reg_vector_freq = regressor_freq.coef_.ravel()
reg_vector_attack = regressor_attack.coef_.ravel()
reg_vector_release = regressor_release.coef_.ravel()

## Control from latent space

In [None]:
# Constraint matrix with one regression direction per row:
# (frequency, attack, release). Each vector is flattened first so that A is
# always 2-D, whether the coefficients are stored as (n,) or (1, n).
A = np.stack([
    np.ravel(vector)
    for vector in (reg_vector_freq, reg_vector_attack, reg_vector_release)
])


def control_latent(z,
                   target_freq,
                   target_attack,
                   target_release,
                   latent_pca1,
                   latent_pca2,
                   ):
    """Move a latent vector so the linear feature models predict the targets.

    The latent is first shifted along two PCA directions and then projected,
    with the smallest possible change, onto the set of latents for which the
    three regressions predict exactly the requested values.

    Args:
        z: Flattened latent vector — assumed to have the same length as the
            regression vectors; TODO confirm against the caller.
        target_freq: Desired dominant frequency, in the units of ``Y_freq``.
        target_attack: Desired attack, in the units of ``Y_attack``.
        target_release: Desired release, in the units of ``Y_release``.
        latent_pca1: Step along the first PCA direction.
        latent_pca2: Step along the second PCA direction.

    Returns:
        The adjusted latent vector.
    """
    # A, pca1, pca2 and the three regressors are module-level names that are
    # only read here, so no `global` statement is needed.
    # NOTE(review): pca1 / pca2 are not defined in the part of the notebook
    # shown here — they must exist before this function is called.

    # Change main variance in latent space
    z = z + pca1 * latent_pca1 + pca2 * latent_pca2

    # One regression direction per row, whatever layout A was stacked with.
    constraints = A.reshape(A.shape[0], -1)

    # Each model predicts `w @ z + intercept`, so the linear part has to
    # reach `target - intercept`, not the raw target.
    intercepts = np.ravel([
        regressor_freq.intercept_,
        regressor_attack.intercept_,
        regressor_release.intercept_,
    ])
    b = np.array([target_freq, target_attack, target_release]) - intercepts

    # Minimum-norm correction. For a full-row-rank matrix the pseudo-inverse
    # equals A.T @ inv(A @ A.T), and it stays defined when the regression
    # directions are collinear.
    residual = constraints @ z - b
    z_prim = z - np.linalg.pinv(constraints) @ residual

    return z_prim
    



## Attack