In [None]:
# %%

import torch.multiprocessing as mp
import torchsummary as summary
import librosa.display
import librosa
import matplotlib.pyplot as plt
import soundfile as sf
import os
import numpy as np
import pandas as pd
import numba as nb
import dask as dk
import joblib as jl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from data.config import *
from data.utils import *
# from data.dataset import MixtureDataset, AudioMixtureDataset
from data.dataset import AudioMixtureDatasetWithLoudnorm
from tqdm import tqdm
from torchlibrosa.stft import STFT, ISTFT, magphase
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

# Let cuDNN benchmark the available convolution algorithms and reuse the
# fastest one it finds for each input configuration.
torch.backends.cudnn.benchmark = True
# Explicitly keep the cuDNN backend switched on.
torch.backends.cudnn.enabled = True

# mp.set_start_method('spawn', force=True)

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [None]:
# %%

class ResidualBlock(nn.Module):
    """Residual unit: two (BatchNorm -> LeakyReLU -> 3x3 conv) stages plus a 1x1 shortcut.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        stride: stride applied by the first 3x3 convolution and by the 1x1
            shortcut projection, so both branches keep matching spatial sizes.
    """

    def __init__(self, in_c, out_c, stride=1):
        super().__init__()

        # Main branch. The layer order defines the Sequential indices that end
        # up in the state_dict keys, so it must stay exactly like this.
        main_layers = [
            nn.BatchNorm2d(in_c),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1),
        ]
        self.residual_block = nn.Sequential(*main_layers)

        # Shortcut connection: 1x1 projection to out_c channels at the same stride.
        self.shortcut = nn.Conv2d(
            in_c, out_c, kernel_size=1, stride=stride, padding=0)

    def forward(self, inputs):
        """Return main_branch(inputs) + shortcut(inputs)."""
        main = self.residual_block(inputs)
        projected = self.shortcut(inputs)
        return main + projected


class DecoderBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, residual block.

    Args:
        in_c: channels of the incoming (coarser) feature map.
        out_c: channels produced by the upsampling and by the block itself.
            The skip tensor must also carry ``out_c`` channels, because the
            residual block is built for ``out_c * 2`` input channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Learned 2x upsampling (kernel 2, stride 2).
        self.upsampling = nn.ConvTranspose2d(
            in_c, out_c, kernel_size=2, stride=2, padding=0, dilation=1)
        # Fuses the upsampled features with the skip connection.
        self.residual_block = ResidualBlock(out_c * 2, out_c)

    def forward(self, x, skip):
        """Upsample ``x``, align it with ``skip``, concatenate and refine."""
        upsampled = self.upsampling(x)

        # The encoder may have produced odd spatial sizes, in which case the
        # 2x upsampling does not land exactly on the skip size: resize to match.
        if upsampled.shape[2:] != skip.shape[2:]:
            upsampled = F.interpolate(
                upsampled,
                size=(skip.shape[2], skip.shape[3]),
                mode='bilinear',
                align_corners=True)

        # Channel-wise concatenation, upsampled features first.
        merged = torch.cat([upsampled, skip], dim=1)
        return self.residual_block(merged)


class ResUNet(nn.Module):
    """Residual U-Net with three encoder levels, a bridge and three decoder levels.

    The network ends in a 1x1 convolution with 3 output channels that are
    turned into a magnitude mask (sigmoid) and real / imaginary masks (tanh).

    Args:
        in_c: number of input channels seen by the first convolution. Note
            that ``forward`` inserts the channel axis itself via ``unsqueeze(1)``.
        out_c: base channel width; deeper levels use 2x, 4x and 8x this value.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        # Encoder 1: plain conv stack (no leading BN/activation on the raw input).
        self.encoder_block1 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1),
        )

        # Shortcut connection paired with encoder 1.
        self.shortcut = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)

        # Encoder 2 and 3: each halves the resolution and doubles the channels.
        self.encoder_block2 = ResidualBlock(out_c, out_c * 2, stride=2)
        self.encoder_block3 = ResidualBlock(out_c * 2, out_c * 4, stride=2)

        # Bridge (bottleneck).
        self.bridge = ResidualBlock(out_c * 4, out_c * 8, stride=2)

        # Decoder: mirrors the encoder widths.
        self.decoder_block1 = DecoderBlock(out_c * 8, out_c * 4)
        self.decoder_block2 = DecoderBlock(out_c * 4, out_c * 2)
        self.decoder_block3 = DecoderBlock(out_c * 2, out_c)

        # Output: 3 channels -> (magnitude, real, imaginary) mask pre-activations.
        self.output = nn.Sequential(
            nn.Conv2d(out_c, 3, kernel_size=1, padding=0),
        )

    def forward(self, inputs):
        """Compute the three masks and expose the deepest encoder feature map.

        Args:
            inputs: tensor without a channel axis; one is inserted at dim 1.
                (Presumably a batch of magnitude spectrograms -- confirm
                against the caller.)

        Returns:
            A tuple ``(masks, skip3)`` where ``masks`` is a dict with keys
            ``'mag_mask'`` (sigmoid), ``'real_mask'`` and ``'imag_mask'``
            (both tanh), and ``skip3`` is the output of the third encoder
            level (``out_c * 4`` channels).
        """
        x = inputs.unsqueeze(1)

        # Encoder 1 with its residual shortcut.
        skip1 = self.encoder_block1(x) + self.shortcut(x)

        # Encoder 2 and 3.
        skip2 = self.encoder_block2(skip1)
        skip3 = self.encoder_block3(skip2)

        # Bridge.
        bottleneck = self.bridge(skip3)

        # Decoder, consuming the skips from deepest to shallowest.
        up = self.decoder_block1(bottleneck, skip3)
        up = self.decoder_block2(up, skip2)
        up = self.decoder_block3(up, skip1)

        # Output: split the 3 channels into individually activated masks.
        raw = self.output(up)
        masks = {
            'mag_mask': torch.sigmoid(raw[:, 0, :, :]),
            'real_mask': torch.tanh(raw[:, 1, :, :]),
            'imag_mask': torch.tanh(raw[:, 2, :, :]),
        }

        return masks, skip3

In [None]:
# %%

class MultiTaskResUNet(nn.Module):
    """ResUNet mask estimator with an additional multi-label noise classifier.

    The classifier reads the third encoder feature map returned by the
    ResUNet, pools it globally and maps it to ``num_noise_classes`` outputs.

    Args:
        num_noise_classes: number of outputs of the classification head.
    """

    def __init__(self, num_noise_classes):
        super().__init__()
        self.resunet = ResUNet(in_c=1, out_c=32)

        # Classification head. 128 input features = 4 * out_c of the ResUNet
        # above, i.e. the channel count of its third encoder level.
        # NOTE: the head ends with a Sigmoid, so it emits probabilities in
        # (0, 1), not logits -- downstream losses must expect probabilities.
        head_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_noise_classes),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return ``(masks, class_probabilities)`` for the input batch.

        ``masks`` is the dict produced by the ResUNet; the second element has
        one sigmoid-activated value per noise class.
        """
        masks, encoder_features = self.resunet(x)
        class_probs = self.classifier(encoder_features)
        return masks, class_probs

In [None]:
# %%

# Define the multi-task loss function
def multi_task_loss(separation_output, classification_output, true_percussion, true_class, alpha=0.7, beta=0.3):
    """Weighted sum of the separation (MSE) and classification (BCE) losses.

    Args:
        separation_output: estimated percussion signal.
        classification_output: per-class *probabilities* in [0, 1], as produced
            by the Sigmoid-terminated classification head of MultiTaskResUNet.
        true_percussion: reference percussion signal, same shape as
            ``separation_output``.
        true_class: multi-hot target of the same shape as
            ``classification_output`` (any numeric/bool dtype; it is cast to
            the dtype of the predictions).
        alpha: weight of the separation loss.
        beta: weight of the classification loss.

    Returns:
        Scalar tensor ``alpha * mse + beta * bce``.
    """
    separation_loss = nn.MSELoss()(separation_output, true_percussion)

    # The classifier already applies a sigmoid. Feeding its output to
    # BCEWithLogitsLoss would squash it a second time: a value in (0, 1)
    # treated as a logit can never push the loss of a negative label below
    # ln(2) ~= 0.693, so the classification term could not be minimised.
    # BCELoss is the variant that operates on probabilities.
    targets = true_class.to(dtype=classification_output.dtype)
    classification_loss = nn.BCELoss()(classification_output, targets)

    return alpha * separation_loss + beta * classification_loss

In [None]:
# %%

# Read the metadata table that sits next to the mixture audio files.
metadata = pd.read_csv(
    os.path.join(DATASET_MIX_AUDIO_PATH, "metadata.csv")
)

# Dataset shared by the train / validation / test loaders.
# Earlier experiments used other dataset classes instead:
#   MixtureDataset(metadata_file=metadata, k=0.6, noise_class=None)
#   AudioMixtureDataset(metadata_file=metadata, k=0.4, noise_class='siren')
#   AudioMixtureDataset(metadata_file=metadata, k=None, noise_class=None)
dataset = AudioMixtureDatasetWithLoudnorm(
    metadata_file=metadata,
    noise_classes=['engine_idling', 'air_conditioner'],
    random_noise=True,
)

In [None]:
# %%

# Re-use the split indices saved on disk for the engine/air-conditioner
# experiment (the generic 'train_indices.npy' / 'val_indices.npy' /
# 'test_indices.npy' files belong to an earlier run).
train_indices = np.load('train_indices_engine_air.npy')
val_indices = np.load('val_indices_engine_air.npy')
test_indices = np.load('test_indices_engine_air.npy')

# One random-order sampler per split, all drawing from the same dataset.
train_sampler, val_sampler, test_sampler = (
    SubsetRandomSampler(split_indices)
    for split_indices in (train_indices, val_indices, test_indices)
)

# Single-process loaders. A multi-worker variant (batch_size 64/128,
# num_workers=2, persistent_workers=True, prefetch_factor=2) was tried
# previously and is currently disabled.
train_loader, val_loader, test_loader = (
    DataLoader(dataset, sampler=split_sampler, batch_size=32)
    for split_sampler in (train_sampler, val_sampler, test_sampler)
)

# data = next(iter(train_loader))

In [None]:
#%%

class SpectrogramReconstructor:
    """Applies a magnitude mask and a phase-rotation mask to a mixture STFT."""

    def __init__(self):
        pass

    def magphase(self, real, imag):
        """Split a (real, imag) pair into magnitude and unit-phase components.

        Returns:
            ``(mag, cos, sin)`` with ``mag = sqrt(real^2 + imag^2)`` and
            ``cos`` / ``sin`` obtained by dividing by the magnitude, which is
            floored at 1e-10 to avoid division by zero.
        """
        mag = (real ** 2 + imag ** 2) ** 0.5
        cos = real / mag.clamp(1e-10, np.inf)
        sin = imag / mag.clamp(1e-10, np.inf)
        return mag, cos, sin

    def reconstruct(self, mag_mask, real_mask, imag_mask, mix_stft):
        """Build the estimated complex spectrogram from the masks.

        The estimated magnitude is ``mag_mask * |mix_stft|``; the estimated
        phase is the mixture phase rotated by the angle encoded in
        ``(real_mask, imag_mask)`` (only their direction is used, not their
        magnitude).

        Returns:
            Complex tensor built with ``torch.complex``.
        """
        mix_mag, mix_cos, mix_sin = self.magphase(mix_stft.real, mix_stft.imag)
        mask_cos, mask_sin = self.magphase(real_mask, imag_mask)[1:]

        # |Y| = |M| * |X|
        estimated_mag = mag_mask * mix_mag

        # Angle-addition formulas: cos(a + b) and sin(a + b).
        rotated_cos = mask_cos * mix_cos - mask_sin * mix_sin
        rotated_sin = mask_cos * mix_sin + mask_sin * mix_cos

        return torch.complex(estimated_mag * rotated_cos,
                             estimated_mag * rotated_sin)


# ISTFT conversion function


def istft(y_complex, n_fft, hop_length, length=31248):
    """Inverse STFT with a Hann window matched to the input.

    Args:
        y_complex: complex spectrogram to invert.
        n_fft: FFT size; also used as the Hann window length.
        hop_length: hop between successive frames.
        length: number of samples of the returned signal. Defaults to 31248,
            the value this function previously hard-coded.

    Returns:
        The time-domain signal returned by ``torch.istft``.
    """
    # Build the window from n_fft and from the input's own device / real
    # dtype instead of a fixed 256-sample CUDA window, so the function also
    # works on CPU tensors and with FFT sizes other than 256.
    window = torch.hann_window(
        n_fft, dtype=y_complex.real.dtype, device=y_complex.device)

    return torch.istft(
        y_complex, n_fft, hop_length, window=window, length=length)

In [None]:
# %%

# Model, optimizer and loss used by the training loop below.
device = "cuda"

model = MultiTaskResUNet(num_noise_classes=8).to(device)

# Plain AdamW; an amsgrad=True variant was tried earlier.
optimizer = AdamW(model.parameters(), lr=0.001)

# Combined separation + classification objective.
criterion = multi_task_loss

In [None]:
# %%

# Train the model
#
# Each epoch runs one pass over train_loader (with optimisation) and one pass
# over val_loader (without gradients), saves a checkpoint, keeps the weights
# with the best validation loss and stops early when validation stops improving.

train_losses = []
val_losses = []
best_val_loss = np.inf
max_patience = 5         # epochs without validation improvement that are tolerated
patience = max_patience  # remaining tolerance; reset whenever validation improves
num_epochs = 10

# To resume a previous run instead of starting from scratch:
# model, optimizer, start_epoch, loss = load_checkpoint(model, optimizer)

start_epoch = 0

# Loop-invariant helpers: the analysis window and the (stateless)
# reconstructor are identical for every batch, so build them once.
stft_window = torch.hann_window(window_length=256, device=device)
reconstructor = SpectrogramReconstructor()

for epoch in range(start_epoch, num_epochs):
    # ------------------------------ Training ------------------------------
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    # NOTE: replacement fields are kept on a single line; line breaks inside
    # them are a SyntaxError on Python versions older than 3.12.
    train_bar = tqdm(
        train_loader,
        desc=f"Epoch {epoch + 1}/{num_epochs} Training Loss: {train_loss:.4f}",
        colour='green')
    for i, batch in enumerate(train_bar):
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Move data to device
        mixture = batch['mixture_audio'].to(device)
        true_percussion = batch['percussion_audio'].to(device)
        # true_class is a multi-hot tensor of shape (batch_size, 8): 1 for the
        # noise classes that are present, 0 for the ones that are absent.
        true_class = batch['noise_labels'].to(device)

        # Complex STFT of the mixture
        mix_stft = torch.stft(mixture, n_fft=256, hop_length=64, win_length=256,
                              window=stft_window, return_complex=True)

        # Forward pass on the magnitude spectrogram. `output` is a dict with
        # the keys 'mag_mask', 'real_mask' and 'imag_mask' (already activated).
        output, class_output = model(torch.abs(mix_stft))

        # Reconstruct the complex spectrogram and go back to the time domain
        Y_complex = reconstructor.reconstruct(
            output['mag_mask'], output['real_mask'], output['imag_mask'], mix_stft)
        percussion_sep = istft(Y_complex, n_fft=256, hop_length=64)

        # Multi-label accuracy: class_output holds sigmoid probabilities, so
        # threshold at 0.5 and count every (sample, class) entry individually.
        # (An argmax over classes does not work here: it yields one index per
        # sample, which does not match the multi-hot shape of true_class.)
        predicted = (class_output > 0.5).float()
        correct += (predicted == true_class).float().sum().item()
        total += true_class.numel()

        # Calculate the loss
        loss = criterion(percussion_sep, class_output,
                         true_percussion, true_class)

        # Backward pass
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Show the running mean of the loss over the batches seen so far
        train_bar.set_description(
            f"Epoch {epoch + 1}/{num_epochs} Training Loss: {train_loss/(i+1):.4f}")

    train_loss /= len(train_loader)
    train_losses.append(train_loss)
    accuracy = correct / total

    # ----------------------------- Validation -----------------------------
    model.eval()
    val_loss = 0
    total = 0
    correct = 0

    val_bar = tqdm(
        val_loader,
        desc=f"Epoch {epoch + 1}/{num_epochs} Validation Loss: {val_loss:.4f}",
        colour='red')
    with torch.no_grad():
        for i, batch in enumerate(val_bar):
            # Move data to device
            mixture = batch['mixture_audio'].to(device)
            true_percussion = batch['percussion_audio'].to(device)
            true_class = batch['noise_labels'].to(device)

            # Complex STFT of the mixture
            mix_stft = torch.stft(mixture, n_fft=256, hop_length=64, win_length=256,
                                  window=stft_window, return_complex=True)

            # Forward pass (same conventions as in the training pass)
            output, class_output = model(torch.abs(mix_stft))

            # Reconstruct the complex spectrogram and go back to the time domain
            Y_complex = reconstructor.reconstruct(
                output['mag_mask'], output['real_mask'], output['imag_mask'], mix_stft)
            percussion_sep = istft(Y_complex, n_fft=256, hop_length=64)

            # Element-wise multi-label accuracy
            predicted = (class_output > 0.5).float()
            correct += (predicted == true_class).float().sum().item()
            total += true_class.numel()

            # Calculate the loss
            loss = criterion(percussion_sep, class_output,
                             true_percussion, true_class)

            val_loss += loss.item()
            val_bar.set_description(
                f"Epoch {epoch + 1}/{num_epochs} Validation Loss: {val_loss/(i+1):.4f}")

    val_loss /= len(val_loader)
    val_losses.append(val_loss)
    val_accuracy = correct / total

    print(f"Epoch {epoch + 1}/{num_epochs} Training Loss: {train_loss:.4f}, "
          f"Training Accuracy: {accuracy:.4f}, Validation Loss: {val_loss:.4f}, "
          f"Validation Accuracy: {val_accuracy:.4f}")

    # Save a checkpoint at the end of every epoch
    save_checkpoint(model, optimizer, epoch, val_loss, checkpoint_dir='checkpoint',
                    filename='checkpoint_air_engine_epoch_{}.pth'.format(epoch + 1))

    # Early stopping on the validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience = max_patience
        torch.save(model.state_dict(), 'best_model_air_engine.pth')
        print("Model improved. Saving the model")
    else:
        patience -= 1
        if patience == 0:
            print("Early stopping")
            break

Epoch 1/10 Training Loss: 0.1994: 100%|[32m██████████[0m| 363/363 [05:09<00:00,  1.17it/s]
Epoch 1/10 Validation Loss: 0.1984: 100%|[31m██████████[0m| 121/121 [01:00<00:00,  2.00it/s]


Epoch 1/10 Training Loss: 0.1994, Training Accuracy: 0.9355, Validation Loss: 0.1984, Validation Accuracy: 0.9382
Checkpoint saved at 'checkpoint\checkpoint_air_engine_epoch_1.pth'
Model improved. Saving the model


Epoch 2/10 Training Loss: 0.1988: 100%|[32m██████████[0m| 363/363 [06:54<00:00,  1.14s/it]
Epoch 2/10 Validation Loss: 0.1985: 100%|[31m██████████[0m| 121/121 [00:58<00:00,  2.08it/s]


Epoch 2/10 Training Loss: 0.1988, Training Accuracy: 0.9370, Validation Loss: 0.1985, Validation Accuracy: 0.9377
Checkpoint saved at 'checkpoint\checkpoint_air_engine_epoch_2.pth'


Epoch 3/10 Training Loss: 0.1986: 100%|[32m██████████[0m| 363/363 [06:59<00:00,  1.16s/it]
Epoch 3/10 Validation Loss: 0.1984: 100%|[31m██████████[0m| 121/121 [00:58<00:00,  2.06it/s]


Epoch 3/10 Training Loss: 0.1986, Training Accuracy: 0.9378, Validation Loss: 0.1984, Validation Accuracy: 0.9382
Checkpoint saved at 'checkpoint\checkpoint_air_engine_epoch_3.pth'


Epoch 4/10 Training Loss: 0.1964:   2%|[32m▏         [0m| 8/363 [00:09<07:15,  1.23s/it]


KeyboardInterrupt: 