# Speech Denoising without Clean Training Data: a Noise2Noise Approach #

Corrected version of the network from the original paper, used to obtain the evaluations on our new song dataset.

Link to the original pretrained network weights: https://www.kaggle.com/datasets/frads01/pretrainedweights

Link to the original speech dataset: https://www.kaggle.com/datasets/frads01/refdataset

### Enter the noise type you want to train the model to denoise. The train and test datasets must already have been generated. ###

### white : additive_gaussian_noise ###
### 0 : air_conditioner ###
### 1 : car_horn ###
### 2 : children_playing ###
### 3 : dog_bark ###
### 4 : drilling ###
### 5 : engine_idling ###
### 6 : gun_shot ###
### 7 : jackhammer ###
### 8 : siren ###
### 9 : street_music ###


In [None]:
# Noise type to denoise: "white" (additive Gaussian noise) or one of the
# class ids 0-9 listed above. The dataset-path cell turns this value into
# folder names.
noise_class = "white" 

### Specify the type of training you want to employ: either "Noise2Noise" or "Noise2Clean"  ###

In [None]:
# "Noise2Noise": training targets come from the "*_Train_Output" folders;
# "Noise2Clean": training targets are the clean training recordings.
training_type =  "Noise2Noise" 

### Dataset paths, output folders and library imports ###

In [None]:
from pathlib import Path

# Build the dataset folder paths from `noise_class` and `training_type`.
# Noisy folders are named "<prefix>_Train_Input", "<prefix>_Train_Output" and
# "<prefix>_Test_Input", where <prefix> is "WhiteNoise" for white noise and
# "US_Class<k>" for noise class k.
_DATASET_ROOT = Path('/kaggle/input/refdataset/Datasets')
_VALID_NOISE_CLASSES = {"white"} | {str(k) for k in range(10)}

if str(noise_class) not in _VALID_NOISE_CLASSES:
    # Fail early: an unknown class would otherwise point at folders that do
    # not exist and silently produce empty file lists later on.
    raise ValueError("Enter valid noise class")

if noise_class == "white":
    _noise_prefix = 'WhiteNoise'
else:
    _noise_prefix = 'US_Class' + str(noise_class)

TRAIN_INPUT_DIR = _DATASET_ROOT / (_noise_prefix + '_Train_Input')

if training_type == "Noise2Noise":
    # Noise2Noise: the target is the paired "_Train_Output" folder.
    TRAIN_TARGET_DIR = _DATASET_ROOT / (_noise_prefix + '_Train_Output')
elif training_type == "Noise2Clean":
    # Noise2Clean: the target is the clean training speech.
    TRAIN_TARGET_DIR = _DATASET_ROOT / 'clean_trainset_28spk_wav'
else:
    # ValueError subclasses Exception, so `except Exception` handlers still work.
    raise ValueError("Enter valid training type")

TEST_NOISY_DIR = _DATASET_ROOT / (_noise_prefix + '_Test_Input')
TEST_CLEAN_DIR = _DATASET_ROOT / 'clean_testset_wav'

In [None]:
import os
# Output folder for this run (e.g. "white_Noise2Noise"), plus its "Weights"
# (checkpoints) and "Samples" sub-folders; existing folders are reused.
basepath = str(noise_class) + "_" + training_type
for _subdir in ("", "/Weights", "/Samples"):
    os.makedirs(basepath + _subdir, exist_ok=True)

In [None]:
import time
import pickle
import warnings
import gc
import copy

import numpy as np
import torch
import torch.nn as nn
import torchaudio

from tqdm import tqdm, tqdm_notebook
from torch.utils.data import Dataset, DataLoader
from matplotlib import colors, pyplot as plt
from IPython.display import clear_output

%matplotlib inline

# Hide DeprecationWarnings emitted by the libraries used in this notebook so
# that cell outputs stay readable.
warnings.simplefilter('ignore', category=DeprecationWarning)

In [None]:
# Seed NumPy and PyTorch so that runs are reproducible.
SEED = 999
np.random.seed(SEED)
torch.manual_seed(SEED)

# On CUDA, force deterministic cuDNN kernels and disable autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Checking whether the GPU is available ###

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
train_on_gpu = torch.cuda.is_available()
print('Training on GPU.' if train_on_gpu else 'No GPU available, training on CPU.')

DEVICE = torch.device('cuda' if train_on_gpu else 'cpu')

In [None]:
# IPython shell escape: print the output of the `nvidia-smi` command.
!nvidia-smi

### Set the audio backend (the code below selects Soundfile; Sox is the usual alternative on Linux) ###

In [None]:
# Select the audio I/O backend. Since torchaudio 2.1 these backend utilities
# are deprecated no-ops (the backend is dispatched per call), and newer
# releases may drop them entirely, so only call them when they still exist.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("soundfile")
if hasattr(torchaudio, "get_audio_backend"):
    print("TorchAudio backend used:\t{}".format(torchaudio.get_audio_backend()))

### Sampling frequency and Short-time Fourier transform parameters ###

In [None]:
# Audio sampling rate in Hz.
SAMPLE_RATE = 48000
# STFT frame length and hop size, chosen in milliseconds and converted to samples.
_WINDOW_MS, _HOP_MS = 64, 16
N_FFT = SAMPLE_RATE * _WINDOW_MS // 1000
HOP_LENGTH = SAMPLE_RATE * _HOP_MS // 1000

### The declaration of datasets and dataloaders ###

In [None]:
class SpeechDataset(Dataset):
    """
    Paired noisy/clean audio dataset that returns STFT tensors.

    Each item loads one noisy and one clean file and keeps the first channel.
    It left-pads the signal with zeros (or truncates it) to ``max_len``
    samples, applies a normalized Short-time Fourier transform, and returns
    the complex spectra as real tensors with a trailing (real, imag)
    dimension of size 2.

    Files are paired by position after sorting both lists, so the noisy and
    clean folders must contain matching, identically ordered file names.
    NOTE(review): the length is taken from ``noisy_files`` only. The sample
    rate returned by ``torchaudio.load`` is ignored, so files are assumed to
    already be at the rate the STFT parameters were chosen for.

    Args:
        noisy_files: paths of the network inputs.
        clean_files: paths of the targets.
        n_fft: STFT frame length in samples.
        hop_length: STFT hop size in samples.
        max_len: number of samples every waveform is padded/cut to.
    """
    def __init__(self, noisy_files, clean_files, n_fft=64, hop_length=16, max_len=165000):
        super().__init__()
        # list of files, sorted so that noisy/clean pairs line up by index
        self.noisy_files = sorted(noisy_files)
        self.clean_files = sorted(clean_files)

        # stft parameters
        self.n_fft = n_fft
        self.hop_length = hop_length

        self.len_ = len(self.noisy_files)

        # fixed length (in samples) of every prepared waveform
        self.max_len = max_len

    def __len__(self):
        return self.len_

    def load_sample(self, file):
        """Load ``file`` with torchaudio and return the waveform (sample rate discarded)."""
        waveform, _ = torchaudio.load(file)
        return waveform

    def __getitem__(self, index):
        """Return ``(noisy_stft, clean_stft)`` for the pair at ``index``."""
        x_clean = self.load_sample(self.clean_files[index])
        x_noisy = self.load_sample(self.noisy_files[index])

        # padding/cutting to a fixed length so that samples can be batched
        x_clean = self._prepare_sample(x_clean)
        x_noisy = self._prepare_sample(x_noisy)

        # view_as_real turns the complex spectrum into a trailing
        # (real, imag) dimension, the layout the complex layers expect.
        x_noisy_stft = torch.view_as_real(self._stft(x_noisy))
        x_clean_stft = torch.view_as_real(self._stft(x_clean))

        return x_noisy_stft, x_clean_stft

    def _stft(self, waveform):
        """Normalized complex STFT with this dataset's parameters.

        Recent PyTorch versions require ``return_complex`` to be given
        explicitly for real inputs.
        """
        return torch.stft(input=waveform, n_fft=self.n_fft,
                          hop_length=self.hop_length, normalized=True,
                          return_complex=True)

    def _prepare_sample(self, waveform):
        """
        Return the first channel of ``waveform`` as a float32 tensor of shape
        ``(1, max_len)``. Shorter signals are left-padded with zeros; longer
        ones keep only their first ``max_len`` samples.
        """
        waveform = waveform.numpy()
        # number of samples actually copied (0 for an empty file)
        n = min(waveform.shape[1], self.max_len)

        output = np.zeros((1, self.max_len), dtype='float32')
        if n > 0:
            # Explicit start index: a negative slice `-n:` would select the
            # whole row when n == 0 and crash on an empty file.
            output[0, self.max_len - n:] = waveform[0, :n]
        return torch.from_numpy(output)

In [None]:
# Collect every .wav file (searched recursively) of each folder, sorted by path.
def _sorted_wavs(folder):
    return sorted(folder.rglob('*.wav'))

train_input_files = _sorted_wavs(TRAIN_INPUT_DIR)
train_target_files = _sorted_wavs(TRAIN_TARGET_DIR)

test_noisy_files = _sorted_wavs(TEST_NOISY_DIR)
test_clean_files = _sorted_wavs(TEST_CLEAN_DIR)

print("No. of Training files:", len(train_input_files))
print("No. of Testing files:", len(test_noisy_files))

In [None]:
# Datasets yielding (noisy, target) STFT pairs for the test and train splits.
test_dataset = SpeechDataset(test_noisy_files, test_clean_files, n_fft=N_FFT, hop_length=HOP_LENGTH)
train_dataset = SpeechDataset(train_input_files, train_target_files, n_fft=N_FFT, hop_length=HOP_LENGTH)

In [None]:
# Shuffled loaders: one file per test batch, two per training batch.
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=1)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=2)

# Fixed-order loader (one file per batch) used to pick a specific test file.
test_loader_single_unshuffled = DataLoader(test_dataset, shuffle=False, batch_size=1)


### Average Test Set Metrics ###

In [None]:
def test_set_metrics(test_loader, model):
    """
    Compute the mean and standard deviation of speech-quality metrics over a loader.

    For every ``(noisy, clean)`` batch, the model's time-domain estimate
    (``is_istft=True``) is compared with the clean waveform reconstructed
    from its STFT. Only the first item of each batch is evaluated.

    Args:
        test_loader: iterable of ``(noisy, clean)`` STFT tensors in the
            ``[..., 2]`` (real, imag) layout produced by ``SpeechDataset``.
        model: network whose ``forward(x, is_istft=True)`` returns waveforms.
            It is switched to eval mode.

    Returns:
        dict mapping each metric name to ``{'mean': ..., 'std_dev': ...}``.

    NOTE(review): ``AudioMetrics`` is not defined in the cells visible here;
    it must be imported before calling this function.
    """
    metric_names = ["CSIG", "CBAK", "COVL", "PESQ", "SSNR", "STOI"]
    overall_metrics = {name: [] for name in metric_names}

    # Eval mode keeps BatchNorm from updating its running statistics with
    # test data; no_grad avoids building an autograd graph.
    model.eval()
    with torch.no_grad():
        for noisy, clean in test_loader:
            x_est = model(noisy.to(DEVICE), is_istft=True)
            x_est_np = x_est[0].view(-1).cpu().numpy()

            clean_spec = torch.squeeze(clean[0], 1)
            # torch.istft only accepts complex input, so turn the trailing
            # (real, imag) dimension into a complex tensor first.
            if not torch.is_complex(clean_spec):
                clean_spec = torch.view_as_complex(clean_spec.contiguous())
            x_c_np = torch.istft(clean_spec, n_fft=N_FFT, hop_length=HOP_LENGTH,
                                 normalized=True).view(-1).cpu().numpy()

            metrics = AudioMetrics(x_c_np, x_est_np, SAMPLE_RATE)
            for name in metric_names:
                overall_metrics[name].append(getattr(metrics, name))

    return {name: {'mean': np.mean(values), 'std_dev': np.std(values)}
            for name, values in overall_metrics.items()}

### Complex-valued layer classes ###

In [None]:
class CConv2d(nn.Module):
    """
    Complex-valued 2-D convolution.

    The input's last dimension (size 2) holds the real and imaginary parts.
    Two real ``nn.Conv2d`` layers, ``real_conv`` (W_r) and ``im_conv`` (W_i),
    implement the complex product
    ``(W_r + i W_i) * (x_r + i x_i) = (W_r x_r - W_i x_i) + i (W_i x_r + W_r x_i)``.
    Both weights use Glorot (Xavier) uniform initialisation.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        settings = dict(in_channels=in_channels, out_channels=out_channels,
                        kernel_size=kernel_size, padding=padding, stride=stride)
        # Construction order (real, then imaginary) fixes the RNG draws.
        self.real_conv = nn.Conv2d(**settings)
        self.im_conv = nn.Conv2d(**settings)

        for layer in (self.real_conv, self.im_conv):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        re, im = x[..., 0], x[..., 1]
        out_re = self.real_conv(re) - self.im_conv(im)
        out_im = self.im_conv(re) + self.real_conv(im)
        return torch.stack((out_re, out_im), dim=-1)

In [None]:
class CConvTranspose2d(nn.Module):
    """
    Complex-valued 2-D transposed convolution.

    Same scheme as the complex convolution: two real ``nn.ConvTranspose2d``
    layers, ``real_convt`` and ``im_convt``, are applied to the real and
    imaginary parts of an input whose last dimension has size 2. Their
    results are combined as a complex product. Both weights use Glorot
    (Xavier) uniform initialisation.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding=0, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.output_padding = output_padding
        self.padding = padding
        self.stride = stride

        settings = dict(in_channels=in_channels, out_channels=out_channels,
                        kernel_size=kernel_size, output_padding=output_padding,
                        padding=padding, stride=stride)
        # Construction order (real, then imaginary) fixes the RNG draws.
        self.real_convt = nn.ConvTranspose2d(**settings)
        self.im_convt = nn.ConvTranspose2d(**settings)

        for layer in (self.real_convt, self.im_convt):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        re, im = x[..., 0], x[..., 1]
        out_re = self.real_convt(re) - self.im_convt(im)
        out_im = self.im_convt(re) + self.real_convt(im)
        return torch.stack((out_re, out_im), dim=-1)

In [None]:
class CBatchNorm2d(nn.Module):
    """
    Complex batch normalisation built from two independent ``BatchNorm2d`` layers.

    ``real_b`` normalises the real parts (``x[..., 0]``) and ``im_b`` the
    imaginary parts (``x[..., 1]``); the two are normalised separately.
    All keyword arguments are forwarded to both ``BatchNorm2d`` layers.
    """
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        def make_bn():
            return nn.BatchNorm2d(num_features=num_features, eps=eps, momentum=momentum,
                                  affine=affine, track_running_stats=track_running_stats)

        self.real_b = make_bn()
        self.im_b = make_bn()

    def forward(self, x):
        return torch.stack((self.real_b(x[..., 0]), self.im_b(x[..., 1])), dim=-1)

In [None]:
class Encoder(nn.Module):
    """
    Encoder block of the DCUNet.

    Complex convolution (``cconv``), then complex batch normalisation
    (``cbn``), then a LeakyReLU applied element-wise to the stacked
    real/imaginary parts.
    """
    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45, padding=(0,0)):
        super().__init__()

        # Hyper-parameters kept as public attributes.
        self.filter_size, self.stride_size = filter_size, stride_size
        self.in_channels, self.out_channels = in_channels, out_channels
        self.padding = padding

        self.cconv = CConv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=filter_size, stride=stride_size, padding=padding)
        self.cbn = CBatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x):
        return self.leaky_relu(self.cbn(self.cconv(x)))

In [None]:
class Decoder(nn.Module):
    """
    Decoder block of the DCUNet, built on a complex transposed convolution (``cconvt``).

    Intermediate layers apply complex batch normalisation (``cbn``) and a
    LeakyReLU. With ``last_layer=True`` the block instead returns the bounded
    mask ``conved / (|conved| + 1e-8) * tanh(|conved|)``; ``cbn`` is created
    but not used in that case.

    NOTE(review): ``cconvt`` returns a real tensor with a trailing
    (real, imag) dimension, so ``torch.abs`` here acts on each component
    separately, not on the complex magnitude.
    """
    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45,
                 output_padding=(0,0), padding=(0,0), last_layer=False):
        super().__init__()

        # Hyper-parameters kept as public attributes.
        self.filter_size, self.stride_size = filter_size, stride_size
        self.in_channels, self.out_channels = in_channels, out_channels
        self.output_padding, self.padding = output_padding, padding
        self.last_layer = last_layer

        self.cconvt = CConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                       kernel_size=filter_size, stride=stride_size,
                                       output_padding=output_padding, padding=padding)
        self.cbn = CBatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x):
        conved = self.cconvt(x)
        if self.last_layer:
            magnitude = torch.abs(conved)
            m_phase = conved / (magnitude + 1e-8)
            m_mag = torch.tanh(magnitude)
            return m_phase * m_mag
        return self.leaky_relu(self.cbn(conved))

### Loss function and evaluation metrics ###

In [None]:
from scipy import interpolate

def resample(original, old_rate, new_rate):
    """
    Resample ``original`` from ``old_rate`` to ``new_rate`` by linear interpolation.

    The first axis is time; any further axes (e.g. channels) are interpolated
    independently. When the two rates are equal the input object itself is
    returned. No anti-aliasing filter is applied before downsampling.
    """
    if old_rate == new_rate:
        return original
    n_in = original.shape[0]
    duration = n_in / old_rate
    n_out = int(n_in * new_rate / old_rate)
    source_times = np.linspace(0, duration, n_in)
    target_times = np.linspace(0, duration, n_out)
    return interpolate.interp1d(source_times, original.T)(target_times).T


def wsdr_fn(x_, y_pred_, y_true_, eps=1e-8, n_fft=None, hop_length=None):
    """
    Weighted source-to-distortion ratio (wSDR) loss.

    Args:
        x_: noisy input STFT. This is either a batch from ``SpeechDataset``
            (``(batch, 1, freq, frames, 2)`` real/imag layout) or an
            equivalent complex tensor.
        y_pred_: predicted time-domain waveforms, ``(batch, samples)``.
        y_true_: target STFT in the same layout as ``x_``.
        eps: small constant keeping the divisions finite.
        n_fft, hop_length: inverse-STFT parameters. They default to the
            notebook globals ``N_FFT`` and ``HOP_LENGTH``.

    Returns:
        Scalar tensor: the batch mean of
        ``a * SDR(y, y_pred) + (1 - a) * SDR(x - y, x - y_pred)``, where
        ``a = |y|^2 / (|y|^2 + |x - y|^2)`` and ``SDR`` is the negative
        cosine similarity. A perfect prediction gives -1.
    """
    if n_fft is None:
        n_fft = N_FFT
    if hop_length is None:
        hop_length = HOP_LENGTH

    def to_waveform(spec):
        # Drop the channel dim, then turn a trailing (real, imag) dim into a
        # complex tensor: torch.istft requires complex input. Self-contained
        # so the loss does not depend on helpers defined in later cells.
        spec = torch.squeeze(spec, 1)
        if not torch.is_complex(spec):
            spec = torch.view_as_complex(spec.contiguous())
        return torch.istft(spec, n_fft=n_fft, hop_length=hop_length,
                           normalized=True, onesided=True)

    def sdr_fn(true, pred):
        # Negative cosine similarity between the signals of each batch item.
        num = torch.sum(true * pred, dim=1)
        den = torch.norm(true, p=2, dim=1) * torch.norm(pred, p=2, dim=1)
        return -(num / (den + eps))

    # Everything as (batch, samples) time-domain signals.
    y_true = to_waveform(y_true_).flatten(1)
    x = to_waveform(x_).flatten(1)
    y_pred = y_pred_.flatten(1)

    # true and estimated noise
    z_true = x - y_true
    z_pred = x - y_pred

    # weight = share of the target's energy in the input mixture
    y_energy = torch.sum(y_true ** 2, dim=1)
    a = y_energy / (y_energy + torch.sum(z_true ** 2, dim=1) + eps)
    wSDR = a * sdr_fn(y_true, y_pred) + (1 - a) * sdr_fn(z_true, z_pred)
    return torch.mean(wSDR)

# Positions (batch indices) of test samples that getMetricsonLoader skips.
wonky_samples = []

def getMetricsonLoader(loader, net, use_net=True):
    """
    Evaluate speech-quality metrics on every batch of ``loader``.

    The estimate depends on ``use_net``:
    - ``use_net=True``: the network output, ``net(noisy, is_istft=True)``.
    - ``use_net=False``: the noisy input itself (the "pre denoising" baseline).

    The estimate is compared with the clean waveform using wide-band PESQ
    (after resampling to 16 kHz) and narrow-band PESQ (8 kHz). SNR, SSNR
    and STOI come from ``AudioMetrics2``. Batches whose index is in
    ``wonky_samples`` are skipped. Each batch is flattened into a single
    signal, so the loader is expected to use ``batch_size=1``.

    Returns:
        dict mapping "PESQ-WB", "PESQ-NB", "SNR", "SSNR" and "STOI" to
        ``{"Mean", "STD", "Min", "Max"}`` (NaN when nothing was evaluated).
        A mean/std summary is also printed.

    NOTE(review): ``pesq`` and ``AudioMetrics2`` are not defined in the cells
    visible here; they must be imported before this function is called.
    """
    net.eval()
    metric_names = ["PESQ-WB", "PESQ-NB", "SNR", "SSNR", "STOI"]
    overall_metrics = [[] for _ in metric_names]
    skipped = set(wonky_samples)

    def to_numpy_waveform(spec):
        # STFT batch -> flat numpy waveform. torch.istft requires complex
        # input, so a trailing (real, imag) dimension is converted first.
        spec = torch.squeeze(spec, 1)
        if not torch.is_complex(spec):
            spec = torch.view_as_complex(spec.contiguous())
        wave = torch.istft(spec, n_fft=N_FFT, hop_length=HOP_LENGTH,
                           normalized=True, onesided=True)
        return wave.view(-1).cpu().numpy()

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for i, (noisy, clean) in enumerate(loader):
            if i in skipped:
                print("Something's up with this sample. Passing...")
                continue
            if use_net:  # forward() returns the istft (time-domain) version
                x_est_np = net(noisy.to(DEVICE), is_istft=True).view(-1).cpu().numpy()
            else:
                x_est_np = to_numpy_waveform(noisy)
            x_clean_np = to_numpy_waveform(clean)

            metrics = AudioMetrics2(x_clean_np, x_est_np, SAMPLE_RATE)

            # PESQ works on 16 kHz (wide-band) or 8 kHz (narrow-band) audio.
            pesq_wb = pesq(16000, resample(x_clean_np, SAMPLE_RATE, 16000),
                           resample(x_est_np, SAMPLE_RATE, 16000), 'wb')
            pesq_nb = pesq(8000, resample(x_clean_np, SAMPLE_RATE, 8000),
                           resample(x_est_np, SAMPLE_RATE, 8000), 'nb')

            scores = (pesq_wb, pesq_nb, metrics.SNR, metrics.SSNR, metrics.STOI)
            for values, score in zip(overall_metrics, scores):
                values.append(score)
    print()
    print("Sample metrics computed")

    results = {}
    for name, values in zip(metric_names, overall_metrics):
        if not values:
            # Empty loader or every sample skipped: min()/max() would raise.
            values = [float("nan")]
        results[name] = {
            "Mean": np.mean(values),
            "STD": np.std(values),
            "Min": min(values),
            "Max": max(values),
        }
    print("Averages computed")

    addon = "(cleaned by model)" if use_net else "(pre denoising)"
    print("Metrics on test data", addon)
    for name, stats in results.items():
        print("{} : {:.3f}+/-{:.3f}".format(name, stats["Mean"], stats["STD"]))
    return results

### Training for one epoch ###

In [None]:
def train_epoch(net, train_loader, loss_fn, optimizer):
    """
    Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    ``loss_fn`` is called as ``loss_fn(noisy, prediction, target)``. An empty
    loader raises ``ZeroDivisionError``. Garbage collection and a CUDA cache
    flush run at the end.
    """
    net.train()
    batch_losses = []
    for noisy_x, clean_x in train_loader:
        noisy_x = noisy_x.to(DEVICE)
        clean_x = clean_x.to(DEVICE)

        net.zero_grad()
        loss = loss_fn(noisy_x, net(noisy_x), clean_x)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    mean_loss = sum(batch_losses) / len(batch_losses)

    # release memory between epochs
    gc.collect()
    torch.cuda.empty_cache()
    return mean_loss

### Validation for one epoch ###

In [None]:
def test_epoch(net, test_loader, loss_fn, use_net=True):
    """
    Evaluate ``net`` on ``test_loader`` and return ``(test_loss, metrics)``.

    The test-loss computation is currently disabled, so ``test_loss`` is
    always ``0.0`` and ``loss_fn`` is not used. ``metrics`` is the result of
    ``getMetricsonLoader``. Garbage collection and a CUDA cache flush run at
    the end.
    """
    net.eval()
    # Placeholder: the per-batch loss_fn evaluation was switched off, but
    # callers still expect a number.
    test_ep_loss = 0.

    testmet = getMetricsonLoader(test_loader, net, use_net)

    # release memory after evaluation
    gc.collect()
    torch.cuda.empty_cache()
    return test_ep_loss, testmet

### To check whether the network is actually learning, the train and test losses are recorded for every epoch. ###

In [None]:
def train(net, train_loader, test_loader, loss_fn, optimizer, scheduler, epochs):
    """
    Train ``net`` for ``epochs`` epochs, evaluating and checkpointing after each.

    Each epoch runs these steps:
    - one ``train_epoch``;
    - one ``scheduler.step()``;
    - an evaluation with ``test_epoch`` under ``torch.no_grad()``;
    - the metrics are appended to ``<basepath>/results.txt``;
    - the model and optimizer ``state_dict`` are saved to
      ``<basepath>/Weights/dc20_model_<epoch>.pth`` and ``dc20_opt_<epoch>.pth``.

    For Noise2Clean training, the metrics of the unprocessed noisy input are
    written first, overwriting ``results.txt``.

    Returns:
        ``(train_losses, test_losses)``: per-epoch lists (empty when
        ``epochs`` is 0).
    """
    train_losses = []
    test_losses = []

    for e in tqdm(range(epochs)):
        if e == 0 and training_type == "Noise2Clean":
            # Baseline: metrics of the noisy input itself, before any training.
            print("Pre-training evaluation")
            testmet = getMetricsonLoader(test_loader, net, False)
            with open(basepath + "/results.txt", "w+") as f:
                f.write("Initial : \n")
                f.write(str(testmet))
                f.write("\n")

        train_loss = train_epoch(net, train_loader, loss_fn, optimizer)
        scheduler.step()

        with torch.no_grad():
            test_loss, testmet = test_epoch(net, test_loader, loss_fn, use_net=True)

        train_losses.append(train_loss)
        test_losses.append(test_loss)

        with open(basepath + "/results.txt", "a") as f:
            f.write("Epoch :" + str(e + 1) + "\n" + str(testmet))
            f.write("\n")
        print("OPed to txt")

        print("Saving model....")
        torch.save(net.state_dict(), basepath + '/Weights/dc20_model_' + str(e + 1) + '.pth')
        torch.save(optimizer.state_dict(), basepath + '/Weights/dc20_opt_' + str(e + 1) + '.pth')
        print("Models saved")

        # clear cache
        torch.cuda.empty_cache()
        gc.collect()

    # Return the full histories; the caller unpacks them as
    # `train_losses, test_losses = train(...)`.
    return train_losses, test_losses

### 20 Layer DCUNet Model ###

In [None]:
class DCUnet20(nn.Module):
    """
    Deep Complex U-Net class of the model.

    20-layer DCUNet (10 Encoder blocks followed by 10 Decoder blocks with skip
    connections) whose last Decoder produces a mask that is multiplied with
    the input spectrogram. Spectrograms use a trailing dimension of size 2
    holding (real, imag). With ``is_istft=True`` the masked spectrogram is
    converted back to a waveform with ``torch.istft``.
    """
    def __init__(self, n_fft=64, hop_length=16):
        """
        Args:
            n_fft: FFT size used by the inverse STFT in ``forward``.
            hop_length: hop size used by the inverse STFT in ``forward``.
        """
        super().__init__()
        
        # for istft
        self.n_fft = n_fft
        self.hop_length = hop_length
        
        # model_complexity = int(45 // 1.414) = 31 base channels.
        self.set_size(model_complexity=int(45//1.414), input_channels=1, model_depth=20)
        # The plain lists only record the call order. The modules are
        # registered through add_module, which gives state_dict keys
        # "encoder<i>.*" / "decoder<i>.*"; renaming them would break loading
        # previously saved checkpoints.
        self.encoders = []
        self.model_length = 20 // 2
        
        for i in range(self.model_length):
            module = Encoder(in_channels=self.enc_channels[i], out_channels=self.enc_channels[i + 1],
                             filter_size=self.enc_kernel_sizes[i], stride_size=self.enc_strides[i], padding=self.enc_paddings[i])
            self.add_module("encoder{}".format(i), module)
            self.encoders.append(module)

        self.decoders = []

        # Decoder i receives its predecessor's output concatenated with the
        # matching encoder input (see forward), hence the summed in_channels.
        for i in range(self.model_length):
            if i != self.model_length - 1:
                module = Decoder(in_channels=self.dec_channels[i] + self.enc_channels[self.model_length - i], out_channels=self.dec_channels[i + 1], 
                                 filter_size=self.dec_kernel_sizes[i], stride_size=self.dec_strides[i], padding=self.dec_paddings[i],
                                 output_padding=self.dec_output_padding[i])
            else:
                # The last decoder outputs the mask (last_layer=True).
                module = Decoder(in_channels=self.dec_channels[i] + self.enc_channels[self.model_length - i], out_channels=self.dec_channels[i + 1], 
                                 filter_size=self.dec_kernel_sizes[i], stride_size=self.dec_strides[i], padding=self.dec_paddings[i],
                                 output_padding=self.dec_output_padding[i], last_layer=True)
            self.add_module("decoder{}".format(i), module)
            self.decoders.append(module)
       
        
    def forward(self, x, is_istft=True):
        """
        Mask the input spectrogram ``x``.

        Args:
            x: spectrogram in (real, imag) layout; assumed to be
               ``(batch, 1, freq, frames, 2)`` as produced by SpeechDataset
               batches (TODO confirm for other callers).
            is_istft: if True, return the time-domain waveform computed with
               ``torch.istft``; otherwise return the masked spectrogram with
               dimension 1 squeezed out.
        """
        # ADDED: make sure the input is on the same device as the model parameters
        device = next(self.parameters()).device
        x = x.to(device)
        # print('x : ', x.shape)
        orig_x = x
        xs = []
        for i, encoder in enumerate(self.encoders):
            # keep every encoder input for the skip connections
            xs.append(x)
            x = encoder(x)
            # print('Encoder : ', x.shape)
            
        p = x
        for i, decoder in enumerate(self.decoders):
            p = decoder(p)
            if i == self.model_length - 1:
                break
            # print('Decoder : ', p.shape)
            # skip connection: concatenate the matching encoder input on the channel dim
            p = torch.cat([p, xs[self.model_length - 1 - i]], dim=1)
        
        # u9 - the mask (output of the last decoder)
        
        mask = p
        
        # print('mask : ', mask.shape)
        
        # NOTE(review): mask and orig_x are both real tensors with a trailing
        # (real, imag) dimension, so this is a component-wise product
        # (re*re, im*im), not a complex multiplication. Left unchanged because
        # the saved weights were presumably trained with this behaviour.
        output = mask * orig_x
        output = torch.squeeze(output, 1)


        if is_istft:
            # REQUIRED CONVERSION: torch.istft needs a native complex tensor,
            # so convert the [..., 2] (real, imag) layout into one.
            output_complex = torch.view_as_complex(output)
            output = torch.istft(output_complex, n_fft=self.n_fft, hop_length=self.hop_length, normalized=True, onesided=True)
        
        return output

    
    def set_size(self, model_complexity, model_depth=20, input_channels=1):
        """
        Set the per-layer channel counts, kernel sizes, strides and paddings.

        Only ``model_depth=20`` is supported; any other value raises
        ``ValueError``. The ``enc_*`` lists describe the 10 encoder blocks and
        the ``dec_*`` lists the 10 decoder blocks, in the order they are applied.
        """

        if model_depth == 20:
            # enc_channels[i] -> enc_channels[i + 1] for encoder i.
            self.enc_channels = [input_channels,
                                 model_complexity,
                                 model_complexity,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 128]

            self.enc_kernel_sizes = [(7, 1),
                                     (1, 7),
                                     (6, 4),
                                     (7, 5),
                                     (5, 3),
                                     (5, 3),
                                     (5, 3),
                                     (5, 3),
                                     (5, 3),
                                     (5, 3)]

            self.enc_strides = [(1, 1),
                                (1, 1),
                                (2, 2),
                                (2, 1),
                                (2, 2),
                                (2, 1),
                                (2, 2),
                                (2, 1),
                                (2, 2),
                                (2, 1)]

            self.enc_paddings = [(3, 0),
                                 (0, 3),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0)]

            # dec_channels[0] is 0: decoder 0 only sees the bottleneck
            # channels (enc_channels[-1]); the final 1 is the mask channel.
            self.dec_channels = [0,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity * 2,
                                 model_complexity,
                                 model_complexity,
                                 1]

            self.dec_kernel_sizes = [(6, 3), 
                                     (6, 3),
                                     (6, 3),
                                     (6, 4),
                                     (6, 3),
                                     (6, 4),
                                     (8, 5),
                                     (7, 5),
                                     (1, 7),
                                     (7, 1)]

            self.dec_strides = [(2, 1), #
                                (2, 2), #
                                (2, 1), #
                                (2, 2), #
                                (2, 1), #
                                (2, 2), #
                                (2, 1), #
                                (2, 2), #
                                (1, 1),
                                (1, 1)]

            self.dec_paddings = [(0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 0),
                                 (0, 3),
                                 (3, 0)]
            
            self.dec_output_padding = [(0,0),
                                       (0,0),
                                       (0,0),
                                       (0,0),
                                       (0,0),
                                       (0,0),
                                       (0,0),
                                       (0,0),
                                       (0,0),
                                       (0,0)]
        else:
            raise ValueError("Unknown model depth : {}".format(model_depth))

## Training New Model ##

In [None]:
# # clear cache
gc.collect()
torch.cuda.empty_cache()

dcunet20 = DCUnet20(N_FFT, HOP_LENGTH).to(DEVICE)
optimizer = torch.optim.Adam(dcunet20.parameters())
loss_fn = wsdr_fn
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

In [None]:
# specify paths and uncomment to resume training from a given point
# model_checkpoint = torch.load(path_to_model)
# opt_checkpoint = torch.load(path_to_opt)
# dcunet20.load_state_dict(model_checkpoint)
# optimizer.load_state_dict(opt_checkpoint)

In [None]:
#train_losses, test_losses = train(dcunet20, train_loader, test_loader, loss_fn, optimizer, scheduler, 4)

## Using pretrained weights to run denoising inference ##

#### Select the model weight .pth file ####

In [None]:
# Pretrained checkpoint: a plain state_dict (it is passed straight to
# load_state_dict below).
model_weights_path = "/kaggle/input/pretrainedweights/Pretrained_Weights/Noise2Noise/1.pth"

dcunet20 = DCUnet20(N_FFT, HOP_LENGTH).to(DEVICE)
optimizer = torch.optim.Adam(dcunet20.parameters())

# Load onto the CPU; load_state_dict later copies the tensors into the
# model's parameters. weights_only=True restricts unpickling to tensors and
# plain containers, so a tampered .pth file cannot run arbitrary code.
checkpoint = torch.load(model_weights_path,
                        map_location=torch.device('cpu'),
                        weights_only=True)

#### Select the testing audio folders for inference ####

In [None]:
# Evaluation data: noisy inputs and their targets, paired by sorted path.
_noisy_dir = Path("/kaggle/input/white_noise_def/filtered_noise2noise_db_white_noise/test/input")
_clean_dir = Path("/kaggle/input/white_noise_def/filtered_noise2noise_db_white_noise/test/target")
test_noisy_files = sorted(_noisy_dir.rglob('*.wav'))
test_clean_files = sorted(_clean_dir.rglob('*.wav'))

test_dataset = SpeechDataset(test_noisy_files, test_clean_files, n_fft=N_FFT, hop_length=HOP_LENGTH)

# Fixed-order loader with one file per batch, so batch k is test file k.
test_loader_single_unshuffled = DataLoader(test_dataset, shuffle=False, batch_size=1)

In [None]:
# Copy the pretrained weights into the model (strict key matching by default).
dcunet20.load_state_dict(checkpoint)

#### Enter the index of the Test Set file to denoise, evaluate and plot (indexing starts from 0) ####

In [None]:
# 0-based position of the test file to denoise (in sorted file order).
index = 2

In [None]:
# Denoise the test file at position `index`. The loader is unshuffled with
# batch size 1, so its k-th batch is the k-th sorted test file.
dcunet20.eval()
test_loader_single_unshuffled_iter = iter(test_loader_single_unshuffled)

x_n, x_c = next(test_loader_single_unshuffled_iter)
for _ in range(index):
    x_n, x_c = next(test_loader_single_unshuffled_iter)

# Inference only: skip building an autograd graph, which saves a large
# amount of memory for a full-length spectrogram.
with torch.no_grad():
    x_est = dcunet20(x_n, is_istft=True)

In [None]:
def safe_istft(stft_tensor, n_fft, hop_length, normalized=True, window=None):
    """
    Call torch.istft, first converting a real (real, imag) STFT to a complex tensor.

    Args:
        stft_tensor: STFT either as a complex tensor, or as a real tensor whose
            last dimension has size 2 and holds the (real, imag) parts.
        n_fft: FFT size of the forward STFT.
        hop_length: hop length of the forward STFT.
        normalized: forwarded to ``torch.istft``.
        window: optional synthesis window forwarded to ``torch.istft``. ``None``
            (the default) keeps torch.istft's default rectangular window, i.e.
            the behaviour this helper always had.

    Returns:
        The reconstructed time-domain signal.
    """
    if stft_tensor.shape[-1] == 2 and not torch.is_complex(stft_tensor):
        # view_as_complex requires the (real, imag) axis to have stride 1;
        # .contiguous() guarantees it and is a no-op for already-contiguous input.
        stft_tensor = torch.view_as_complex(stft_tensor.contiguous())

    return torch.istft(stft_tensor, n_fft=n_fft, hop_length=hop_length,
                       normalized=normalized, window=window)

In [None]:
# Flatten the three signals into 1-D numpy waveforms for plotting and saving.
# x_est is already a waveform (the model was called with is_istft=True), of which
# the first batch item is kept; x_c and x_n are squeezed on axis 1 and passed
# through safe_istft, i.e. treated as STFTs.
x_est_np = x_est[0].view(-1).detach().cpu().numpy()
x_c_np = safe_istft(torch.squeeze(x_c, 1), N_FFT, HOP_LENGTH).view(-1).detach().cpu().numpy()
x_n_np = safe_istft(torch.squeeze(x_n, 1), N_FFT, HOP_LENGTH).view(-1).detach().cpu().numpy()

#### Metrics ####

In [None]:
# === SNR Function ===
def compute_snr(clean, estimate):
    """
    Return the SNR (dB) of ``estimate`` with respect to the reference ``clean``.

    SNR = 10 * log10(sum(clean**2) / (sum((clean - estimate)**2) + eps)),
    where eps is the machine epsilon of the reference's float dtype.

    Both signals are flattened and truncated to their common length, so
    reconstructions that differ by a few samples can still be compared.

    Args:
        clean: reference signal (torch tensor, numpy array or sequence).
        estimate: signal to score (same accepted types).

    Returns:
        The SNR as a Python float.
    """
    def _as_float_tensor(x):
        # Non-tensors become float32 tensors (as before); integer tensors are
        # promoted to float so torch.finfo below is defined for their dtype.
        if not torch.is_tensor(x):
            return torch.tensor(np.asarray(x), dtype=torch.float32)
        return x if x.is_floating_point() else x.float()

    clean = _as_float_tensor(clean).reshape(-1)
    estimate = _as_float_tensor(estimate).reshape(-1).to(clean.device)

    # Compare only the common prefix when lengths differ.
    n = min(clean.numel(), estimate.numel())
    clean, estimate = clean[:n], estimate[:n]

    signal_power = torch.sum(clean ** 2)
    noise_power = torch.sum((clean - estimate) ** 2)

    # eps avoids a division by zero when estimate == clean.
    eps = torch.finfo(clean.dtype).eps
    snr = 10 * torch.log10(signal_power / (noise_power + eps))
    return snr.item()


#### Visualization of the denoised test sample selected above ####

#### Noisy audio waveform ####

In [None]:
# Noisy input waveform (x axis: sample index).
plt.plot(x_n_np)

#### Model denoised audio waveform ####

In [None]:
# Model output waveform (x axis: sample index).
plt.plot(x_est_np)

#### True clean audio waveform ####

In [None]:
# Clean reference waveform (x axis: sample index).
plt.plot(x_c_np)

#### Save the denoised, clean and noisy audio of the selected sample ####

In [None]:
import soundfile as sf

# Write the numpy waveforms directly as WAV files at SAMPLE_RATE.
sf.write("/kaggle/working/predicted.wav", x_est_np, SAMPLE_RATE)
sf.write("/kaggle/working/clean.wav", x_c_np, SAMPLE_RATE) 
sf.write("/kaggle/working/noisy.wav", x_n_np, SAMPLE_RATE)


In [None]:
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path

# Device e parametri globali
# Device used by the evaluation helpers below: CUDA when available, else CPU.
# NOTE(review): this is independent of the DEVICE global used by the earlier cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class AudioProcessor:
    """Turn model inputs/outputs into waveforms for metric computation.

    NOTE(review): stft_to_waveform inverts with a Hann window, whereas the
    notebook's safe_istft relies on torch.istft's default window. Confirm which
    window the dataset's forward STFT uses — a mismatch would distort the
    reconstructed clean/noisy references.
    """

    def __init__(self, n_fft, hop_length, sample_rate):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        # Hann synthesis window, created once on the global evaluation device.
        self.window = torch.hann_window(n_fft).to(device)

    def stft_to_waveform(self, stft_tensor):
        """Invert a real (..., 2) STFT tensor to a waveform via torch.istft.

        5-D input keeps only item ``[0, 0]`` and 4-D input only item ``[0]``
        of its leading axes before the inversion.
        """
        leading_index = {5: (0, 0), 4: (0,)}.get(stft_tensor.dim())
        if leading_index is not None:
            stft_tensor = stft_tensor[leading_index]

        spectrum = torch.view_as_complex(stft_tensor)
        return torch.istft(
            spectrum,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            normalized=True,
            return_complex=False,
        )

    def ensure_waveform(self, tensor):
        """Return ``tensor`` as a waveform.

        Tensors with at least 3 dims and a trailing axis of size 2 are treated
        as real STFTs and inverted; anything else has its leading singleton
        axes removed (keeping at least one axis).
        """
        looks_like_stft = tensor.dim() >= 3 and tensor.size(-1) == 2
        if looks_like_stft:
            return self.stft_to_waveform(tensor)

        squeezed = tensor
        while squeezed.dim() > 1 and squeezed.size(0) == 1:
            squeezed = squeezed.squeeze(0)
        return squeezed

class SingleSamplePlotter:
    """Save a three-panel (clean / noisy / denoised) waveform plot per sample."""

    def __init__(self, sample_rate, save_dir='/kaggle/working'):
        self.sample_rate = sample_rate
        self.save_dir = Path(save_dir)
        # parents=True so a nested save_dir does not fail on a missing parent.
        self.save_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _to_numpy(signal):
        """Return ``signal`` (tensor or array-like) as a squeezed CPU numpy array."""
        if torch.is_tensor(signal):
            return signal.detach().cpu().numpy().squeeze()
        return np.asarray(signal).squeeze()

    def plot(self, clean_wave, noisy_wave, predicted_wave, sample_idx, snr_value, snr_improvement):
        """Plot the three waveforms of one sample and save them as a PNG.

        The signals are truncated to their common length and drawn against time
        in seconds; the denoised panel's title reports ``snr_value`` and
        ``snr_improvement`` (dB).

        Returns:
            Path of the saved PNG as a string, or ``None`` if plotting failed
            (the error is printed, not raised — best-effort by design).
        """
        fig = None
        try:
            clean_np, noisy_np, pred_np = map(self._to_numpy, [clean_wave, noisy_wave, predicted_wave])
            min_len = min(len(clean_np), len(noisy_np), len(pred_np))
            clean_np = clean_np[:min_len]
            noisy_np = noisy_np[:min_len]
            pred_np = pred_np[:min_len]
            time_axis = np.arange(min_len) / self.sample_rate

            fig, axes = plt.subplots(3, 1, figsize=(12, 8))
            colors = ['#2E8B57', '#DC143C', '#1E90FF']
            signals = [('Clean (Ref)', clean_np), ('Noisy (Input)', noisy_np), ('Denoised (Output)', pred_np)]
            for i, (label, signal) in enumerate(signals):
                axes[i].plot(time_axis, signal, color=colors[i], linewidth=0.8, alpha=0.9)
                if i == 2:
                    title = f'{label} - SNR: {snr_value:.2f} dB | ΔSNR: {snr_improvement:+.2f} dB'
                else:
                    title = f'Sample {sample_idx} - {label}'
                axes[i].set_title(title, fontsize=11)
                axes[i].set_ylabel('Amplitude')
                axes[i].grid(True, alpha=0.3)
            axes[-1].set_xlabel('Time (seconds)')
            plt.tight_layout()

            filename = self.save_dir / f'sample_{sample_idx:03d}_plot.png'
            fig.savefig(filename, dpi=120, bbox_inches='tight', facecolor='white')
            return str(filename)
        except Exception as e:
            print(f" Errore plot campione {sample_idx}: {e}")
            return None
        finally:
            # Always release the figure — previously a failure after
            # plt.subplots left it open, leaking one figure per failed sample.
            if fig is not None:
                plt.close(fig)


class SNRCalculator:
    """Stateless helpers for global SNR, segmental SNR and length alignment."""

    @staticmethod
    def compute_snr(clean, estimate, epsilon=1e-10):
        """Return 10*log10(energy(clean) / energy(clean - estimate)) as a float.

        The residual energy is clamped to at least ``epsilon``.
        """
        reference_energy = torch.sum(clean ** 2)
        residual_energy = torch.sum((clean - estimate) ** 2).clamp(min=epsilon)
        return (10 * torch.log10(reference_energy / residual_energy)).item()

    @staticmethod
    def compute_ssnr(clean, estimate, frame_length=1440, frame_shift=720, C_min=-10, C_max=35):
        """Return the segmental SNR (dB): the mean of per-frame SNRs.

        Both signals are squeezed, moved to numpy and cut to their common
        length. Each frame's SNR is clipped to ``[C_min, C_max]``; frames with
        zero reference energy (or a non-finite ratio) count as ``C_min``.
        Returns NaN when the signals are shorter than one frame.
        """
        ref = clean.squeeze().cpu().numpy()
        est = estimate.squeeze().cpu().numpy()
        common = min(len(ref), len(est))
        ref, est = ref[:common], est[:common]

        frame_count = (common - frame_length) // frame_shift + 1
        if frame_count < 1:
            return float('nan')

        def frame_snr(offset):
            ref_seg = ref[offset:offset + frame_length]
            err_seg = ref_seg - est[offset:offset + frame_length]
            # Small constant keeps the ratio finite for a perfect estimate.
            ratio = np.sum(ref_seg ** 2) / (np.sum(err_seg ** 2) + 1e-8)
            usable = ratio > 0 and np.isfinite(ratio)
            value = 10 * np.log10(ratio) if usable else C_min
            return np.clip(value, C_min, C_max)

        return np.mean([frame_snr(k * frame_shift) for k in range(frame_count)])

    @staticmethod
    def align_tensors(*tensors):
        """Truncate every tensor's last axis to the shortest one among them."""
        shortest = min(t.shape[-1] for t in tensors)
        return [t[..., :shortest] for t in tensors]

def save_audio_sample(clean_wave, noisy_wave, predicted_wave, sample_idx, 
                     sample_rate, save_dir='/kaggle/working'):
    """Save the clean, noisy and denoised audio of one sample as WAV files.

    Files are named ``sample_{idx:03d}_{clean|noisy|denoised}.wav`` inside
    ``save_dir``. Saving is best-effort: an error is printed, not raised.

    Args:
        clean_wave, noisy_wave, predicted_wave: waveforms as tensors or
            array-likes; 1-D signals are saved as a single channel.
        sample_idx: sample number used in the file names.
        sample_rate: sample rate written to the WAV headers.
        save_dir: output directory (created, with parents, if missing).

    Returns:
        List of the file paths (strings) written before any error.
    """
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    def to_cpu_tensor(signal):
        # torchaudio.save takes a (channels, samples) tensor: promote 1-D
        # signals to one channel, for tensors and array-likes alike (the old
        # code always unsqueezed array-likes, turning a 2-D array into 3-D).
        if torch.is_tensor(signal):
            tensor = signal.detach().cpu()
        else:
            tensor = torch.tensor(np.asarray(signal))
        return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

    # Convert all three up front, outside the try, as before.
    outputs = [
        ('clean', to_cpu_tensor(clean_wave)),
        ('noisy', to_cpu_tensor(noisy_wave)),
        ('denoised', to_cpu_tensor(predicted_wave)),
    ]

    files_saved = []
    try:
        for kind, tensor in outputs:
            out_file = save_path / f'sample_{sample_idx:03d}_{kind}.wav'
            torchaudio.save(str(out_file), tensor, sample_rate)
            files_saved.append(str(out_file))

        print(f" Audio salvati per campione {sample_idx}:")
        for file in files_saved:
            print(f"   • {Path(file).name}")

    except Exception as e:
        print(f" Errore salvataggio audio campione {sample_idx}: {e}")

    return files_saved

def evaluate_testset(model, test_loader, audio_processor, snr_calculator, 
                    sample_rate, save_audio_for=None, plot_samples=None, plotter=None):
    """Evaluate ``model`` on every batch of ``test_loader`` and collect metrics.

    Steps for each (noisy, clean) batch:
    - move it to ``device`` and denoise it with ``model(noisy, is_istft=True)``;
    - convert prediction, clean and noisy to waveforms via ``audio_processor``;
    - trim them to a common length;
    - record global SNR and segmental SNR of the prediction, plus their
      improvement over the noisy input.

    Samples are numbered from 1. Numbers in ``save_audio_for`` have their audio
    written with ``save_audio_sample``, and numbers in ``plot_samples`` are
    plotted with ``plotter``. An exception while processing a sample is printed
    and that sample is skipped (none of its metrics are recorded).

    Returns:
        ``(snr_values, snr_improvements, ssnr_values, ssnr_improvements)``.
    """
    model.eval()
    snr_out, snr_gain, ssnr_out, ssnr_gain = [], [], [], []

    print(" Avvio valutazione test set...")

    with torch.no_grad():
        progress = tqdm(test_loader, desc="Processing")
        for sample_no, (noisy_batch, clean_batch) in enumerate(progress, start=1):
            try:
                noisy_batch = noisy_batch.to(device, non_blocking=True)
                clean_batch = clean_batch.to(device, non_blocking=True)

                denoised_raw = model(noisy_batch, is_istft=True)

                # Waveform conversion: prediction first, then clean, then noisy.
                denoised_wave = audio_processor.ensure_waveform(denoised_raw)
                clean_wave = audio_processor.ensure_waveform(clean_batch)
                noisy_wave = audio_processor.ensure_waveform(noisy_batch)

                clean_ref, denoised, noisy_ref = snr_calculator.align_tensors(
                    clean_wave, denoised_wave, noisy_wave)

                snr_pred = snr_calculator.compute_snr(clean_ref, denoised)
                snr_delta = snr_pred - snr_calculator.compute_snr(clean_ref, noisy_ref)
                ssnr_pred = snr_calculator.compute_ssnr(clean_ref, denoised)
                ssnr_delta = ssnr_pred - snr_calculator.compute_ssnr(clean_ref, noisy_ref)

                # Record only once every metric of this sample is available.
                snr_out.append(snr_pred)
                snr_gain.append(snr_delta)
                ssnr_out.append(ssnr_pred)
                ssnr_gain.append(ssnr_delta)

                if save_audio_for and sample_no in save_audio_for:
                    save_audio_sample(clean_ref, noisy_ref, denoised, sample_no, sample_rate)

                if plotter and plot_samples and sample_no in plot_samples:
                    plot_path = plotter.plot(clean_ref, noisy_ref, denoised, sample_no, snr_pred, snr_delta)
                    print(f" Plot salvato per sample {sample_no}: {plot_path}")

            except Exception as e:
                print(f" Errore campione {sample_no}: {e}")

    return snr_out, snr_gain, ssnr_out, ssnr_gain

def plot_results(snr_values, snr_improvements, save_dir='/kaggle/working'):
    """Draw and save side-by-side histograms of SNR and SNR improvement.

    Args:
        snr_values: per-sample SNR of the prediction vs. the clean reference (dB).
        snr_improvements: per-sample SNR gain over the noisy input (dB).
        save_dir: directory for ``snr_distributions.png`` (created if missing).

    Returns:
        Path of the saved PNG as a string.
    """
    def _finite(values):
        # SNR can be -inf/NaN (e.g. log10(0) for an all-silent reference), and
        # Axes.hist cannot auto-range such values — drop them before plotting.
        arr = np.asarray(values, dtype=float).ravel()
        kept = arr[np.isfinite(arr)]
        if kept.size < arr.size:
            print(f" {arr.size - kept.size} valori non finiti esclusi dal grafico")
        return kept

    snr_values = _finite(snr_values)
    snr_improvements = _finite(snr_improvements)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # SNR distribution
    ax1.hist(snr_values, bins=40, alpha=0.7, color='blue', edgecolor='navy')
    mean_snr = np.mean(snr_values)
    # ax1.set_xlim(0, 30)  # uncomment to use a fixed x-axis range
    ax1.axvline(mean_snr, color='red', linestyle='--', linewidth=2,
               label=f'Media: {mean_snr:.2f} dB')
    ax1.set_title('Distribuzione SNR Predicted vs Clean')
    ax1.set_xlabel('SNR (dB)')
    ax1.set_ylabel('Frequenza')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # SNR-improvement distribution, with a reference line at 0 dB
    ax2.hist(snr_improvements, bins=40, alpha=0.7, color='green', edgecolor='darkgreen')
    mean_imp = np.mean(snr_improvements)
    ax2.axvline(mean_imp, color='red', linestyle='--', linewidth=2,
               label=f'Media: {mean_imp:.2f} dB')
    ax2.axvline(0, color='black', linestyle='-', alpha=0.5, label='No improvement')
    ax2.set_title('Distribuzione SNR Improvements')
    ax2.set_xlabel('SNR Improvement (dB)')
    ax2.set_ylabel('Frequenza')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()

    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)
    plot_file = save_path / 'snr_distributions.png'
    plt.savefig(plot_file, dpi=150, bbox_inches='tight')
    plt.show()

    return str(plot_file)

def main(save_audio_samples=None, plot_samples=None,
         model_weights_path="/kaggle/input/pretrainedweights/Pretrained_Weights/Noise2Noise/white.pth",
         test_loader=None):
    """Evaluate a pretrained DCUnet20 on the test set and print SNR/SSNR statistics.

    Args:
        save_audio_samples: 1-based sample numbers whose clean/noisy/denoised
            audio is saved, e.g. ``[1, 5, 10]``; ``None`` saves nothing.
        plot_samples: 1-based sample numbers whose waveforms are plotted.
        model_weights_path: state-dict file to load (defaults to the path
            that was previously hard-coded).
        test_loader: loader to evaluate; ``None`` uses the notebook-level
            ``test_loader_single_unshuffled``.

    Returns:
        ``(snr_values, snr_improvements)`` per-sample lists.
    """
    if test_loader is None:
        test_loader = test_loader_single_unshuffled

    # Initialize components
    audio_processor = AudioProcessor(N_FFT, HOP_LENGTH, SAMPLE_RATE)
    plotter = SingleSamplePlotter(SAMPLE_RATE)
    snr_calculator = SNRCalculator()

    # Load the model
    print(" Caricamento modello...")
    model = DCUnet20(n_fft=N_FFT, hop_length=HOP_LENGTH).to(device)
    model.load_state_dict(torch.load(model_weights_path, map_location=device))

    print(f" Setup completato - Device: {device}")

    # Evaluation
    snr_values, snr_improvements, ssnr_values, ssnr_improvements = evaluate_testset(
        model=model,
        test_loader=test_loader,
        audio_processor=audio_processor,
        snr_calculator=snr_calculator,
        sample_rate=SAMPLE_RATE,
        save_audio_for=save_audio_samples,
        plot_samples=plot_samples,
        plotter=plotter
    )

    # Summary report
    print("\n" + "="*60)
    print(" RISULTATI FINALI DEL TEST SET")
    print("="*60)
    print(f"Totale campioni processati: {len(snr_values)}")

    # evaluate_testset skips failing samples; if all failed, np.min and the
    # percentage below would raise, so stop here.
    if not snr_values:
        print("Nessun campione valutato: statistiche e grafici non disponibili.")
        return snr_values, snr_improvements

    print(f"SNR medio predicted vs clean: {np.mean(snr_values):.2f} ± {np.std(snr_values):.2f} dB")
    print(f"SNR improvement medio: {np.mean(snr_improvements):.2f} ± {np.std(snr_improvements):.2f} dB")
    print(f"SNR minimo: {np.min(snr_values):.2f} dB")
    print(f"SNR massimo: {np.max(snr_values):.2f} dB")
    print(f"Mediana SNR: {np.median(snr_values):.2f} dB")

    # compute_ssnr returns NaN for clips shorter than one frame; ignore those.
    print(f"SSNR medio predicted vs clean: {np.nanmean(ssnr_values):.2f} ± {np.nanstd(ssnr_values):.2f} dB")
    print(f"SSNR improvement medio: {np.nanmean(ssnr_improvements):.2f} ± {np.nanstd(ssnr_improvements):.2f} dB")

    n_improved = sum(1 for x in snr_improvements if x > 0)
    improvement_rate = n_improved / len(snr_improvements) * 100
    print(f"Campioni con miglioramento: {n_improved}/{len(snr_improvements)} ({improvement_rate:.1f}%)")

    # Histograms
    plot_file = plot_results(snr_values, snr_improvements)
    print(f"\n💾 Grafico salvato: {plot_file}")

    return snr_values, snr_improvements

# ESECUZIONE
# Per salvare audio di campioni specifici, usa:
# results = main(save_audio_samples=[1, 5, 10])  # salva audio dei campioni 1, 5 e 10
# 
# Per non salvare audio:
# results = main()

if __name__ == "__main__":
    # 1-based sample numbers whose audio is saved and whose waveforms are
    # plotted; edit this list to pick other samples.
    SAMPLES_TO_SAVE_AND_PLOT = [1, 3, 10]

    print(" Avvio valutazione...")
    results = main(
        save_audio_samples=SAMPLES_TO_SAVE_AND_PLOT,
        plot_samples=SAMPLES_TO_SAVE_AND_PLOT,
    )
    print(" Valutazione completata!")