# Noise Remover


I dataset di train e di test devono essere già stati generati prima di eseguire questo notebook.

# Import librerie:

In [None]:
import torch
import torchaudio
import numpy as np
import torch.nn as nn
import matplotlib
matplotlib.use('Agg') # backend per salvataggio file, no GUI 
import matplotlib.pyplot as plt
plt.ioff() # Disabilita modalità interattiva     
from torch.utils.data import Dataset, DataLoader
import gc
from tqdm import tqdm
from pathlib import Path 

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Indicazione Path per i dati di input e target:

In [None]:
# SET THIS to the root of the dataset that holds the segmented songs.
database_dir = Path("/kaggle/input/white-noise-def/filtered_noise2noise_db_white_noise")

# Output tree for model weights and audio samples; created on first run.
base_dir = Path("results")
base_dir.joinpath("Weights").mkdir(parents=True, exist_ok=True)
base_dir.joinpath("Samples").mkdir(parents=True, exist_ok=True)

# Training pairs: network input and regression target.
TRAIN_INPUT_DIR = database_dir.joinpath('train', 'input')
TRAIN_TARGET_DIR = database_dir.joinpath('train', 'target')

# Evaluation pairs: noisy input and its reference.
TEST_NOISY_DIR = database_dir.joinpath('test', 'input')
TEST_CLEAN_DIR = database_dir.joinpath('test', 'target')

## Parametri per la trasformazione STFT

In [None]:
# Global STFT configuration shared by the dataset, the loss and the metrics.
# SAMPLE_RATE is the rate every loaded file is resampled to.
SAMPLE_RATE = 44100
n_fft = 2048 # window size (frequency resolution) - the paper uses 3072, which works very well for speech
hop_length = 512 # hop between consecutive windows (time resolution) - the paper uses 768

# Hann window of n_fft samples, placed on the training device.
window = torch.hann_window(n_fft).to(device)

# Dichiarazioni del Dataset e del Dataloader

In [None]:
class SimpleNoise2NoiseDataset(Dataset):
    """Paired audio dataset that yields the STFT of an (input, target) file pair.

    Both file collections are sorted independently and paired by position, so
    the i-th input is matched with the i-th target (this assumes the two
    folders use matching file names).

    Args:
        noisy_file_set_A: iterable of paths to the input recordings.
        noisy_file_set_B: iterable of paths to the target recordings.
        n_fft: FFT/window size used for the STFT.
        hop_length: hop between consecutive STFT frames.

    Raises:
        AssertionError: if the two collections differ in length.
    """

    def __init__(self, noisy_file_set_A, noisy_file_set_B, n_fft=1024, hop_length=256):
        super().__init__()

        self.noisy_A = sorted(noisy_file_set_A)
        self.noisy_B = sorted(noisy_file_set_B)
        self.n_fft = n_fft
        self.hop_length = hop_length
        # The window must have exactly n_fft samples, so it is built from this
        # dataset's own n_fft rather than borrowed from a module-level window
        # sized for a possibly different n_fft. It is kept on the CPU because
        # audio is loaded on the CPU, which avoids a device copy per item.
        self.window = torch.hann_window(n_fft)

        assert len(self.noisy_A) == len(self.noisy_B), "Input e target devono avere lo stesso numero di file"

    def load_audio_file(self, file_path):
        """Load a whole audio file, resample it to SAMPLE_RATE and downmix to mono.

        Returns the waveform with the channel dimension kept (one channel).
        """
        waveform, sr = torchaudio.load(file_path)
        if sr != SAMPLE_RATE:
            waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
        if waveform.shape[0] > 1:
            # Average the channels but keep the channel axis.
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

    def _to_stft(self, waveform):
        """Return the normalized STFT of ``waveform`` as a real tensor (last dim = real/imag)."""
        spectrum = torch.stft(waveform, n_fft=self.n_fft, hop_length=self.hop_length,
                              window=self.window.to(waveform.device),
                              normalized=True, return_complex=True)
        return torch.view_as_real(spectrum)

    def __len__(self):
        return len(self.noisy_A)

    def __getitem__(self, index):
        """Return ``(input_stft, target_stft)`` for the pair at ``index``."""
        x1 = self.load_audio_file(self.noisy_A[index])
        x2 = self.load_audio_file(self.noisy_B[index])
        return self._to_stft(x1), self._to_stft(x2)


# Collect the .wav segments of every split. Each list is sorted because the
# dataset pairs inputs and targets by position.
files_noise_input = sorted(TRAIN_INPUT_DIR.rglob("*.wav"))
files_noise_target = sorted(TRAIN_TARGET_DIR.rglob("*.wav"))
test_noisy_files = sorted(TEST_NOISY_DIR.rglob('*.wav'))
test_clean_files = sorted(TEST_CLEAN_DIR.rglob('*.wav'))

print("Numero di segmenti per il train:",len(files_noise_input))
print("Numero di segmenti per il test:",len(test_noisy_files))

# Training pairs (input / target) and evaluation pairs (noisy / reference),
# both transformed with the global STFT settings.
noise2noise_dataset = SimpleNoise2NoiseDataset(
    noisy_file_set_A=files_noise_input,
    noisy_file_set_B=files_noise_target,
    n_fft=n_fft,
    hop_length=hop_length,
)

test_dataset = SimpleNoise2NoiseDataset(
    noisy_file_set_A=test_noisy_files,
    noisy_file_set_B=test_clean_files,
    n_fft=n_fft,
    hop_length=hop_length,
)

train_loader = DataLoader(dataset=noise2noise_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=2, shuffle=True)

# One item at a time, in file order.
test_loader_single_unshuffled = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

# Definizione dei diversi layer

### Layer convoluzionale per segnali complessi:

In [None]:
class ComplexConv2d(nn.Module):
    """2-D convolution over complex tensors stored as real tensors.

    The input carries the real part at ``x[..., 0]`` and the imaginary part at
    ``x[..., 1]``. Two independent ``nn.Conv2d`` layers (each with its own
    weights and bias) act as the real and imaginary kernels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        # Real kernel first, imaginary kernel second.
        self.real_conv = nn.Conv2d(**conv_kwargs)
        self.im_conv = nn.Conv2d(**conv_kwargs)

        # Glorot initialization of both kernels.
        for conv in (self.real_conv, self.im_conv):
            nn.init.xavier_normal_(conv.weight)

    def forward(self, x):
        """Apply the complex convolution; the result keeps real/imag on the last axis."""
        real_in, imag_in = x[..., 0], x[..., 1]

        # (W_r + i W_i) * (x_r + i x_i)
        out_real = self.real_conv(real_in) - self.im_conv(imag_in)
        out_imag = self.im_conv(real_in) + self.real_conv(imag_in)

        return torch.stack([out_real, out_imag], dim=-1)

### Layer per deconvoluzione di segnali complessi:

In [None]:
class ComplexConvTranspose2d(nn.Module):
    """2-D transposed convolution over complex tensors stored as real tensors.

    Real and imaginary parts live at ``x[..., 0]`` and ``x[..., 1]``. Two
    independent ``nn.ConvTranspose2d`` layers play the role of the real and
    imaginary kernels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding=0, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.output_padding = output_padding
        self.padding = padding
        self.stride = stride

        convt_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            output_padding=output_padding,
            padding=padding,
            stride=stride,
        )
        # Real kernel first, imaginary kernel second.
        self.real_convt = nn.ConvTranspose2d(**convt_kwargs)
        self.im_convt = nn.ConvTranspose2d(**convt_kwargs)

        # Glorot initialization of both kernels.
        for convt in (self.real_convt, self.im_convt):
            nn.init.xavier_normal_(convt.weight)

    def forward(self, x):
        """Apply the complex transposed convolution; real/imag stay on the last axis."""
        real_in, imag_in = x[..., 0], x[..., 1]

        out_real = self.real_convt(real_in) - self.im_convt(imag_in)
        out_imag = self.im_convt(real_in) + self.real_convt(imag_in)

        return torch.stack([out_real, out_imag], dim=-1)

### Layer per la batch normalization di segnali complessi:

In [None]:
class ComplexBatchNorm2d(nn.Module):
    """Batch normalization for complex tensors stored as real tensors.

    The real part (``x[..., 0]``) and the imaginary part (``x[..., 1]``) are
    normalized independently, each by its own ``nn.BatchNorm2d``.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        bn_kwargs = dict(
            num_features=num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        self.real_b = nn.BatchNorm2d(**bn_kwargs)
        self.im_b = nn.BatchNorm2d(**bn_kwargs)

    def forward(self, x):
        """Normalize both parts and stack them back on the last axis."""
        normed_real = self.real_b(x[..., 0])
        normed_imag = self.im_b(x[..., 1])
        return torch.stack([normed_real, normed_imag], dim=-1)

### Layer Encoder:

In [None]:
class Encoder(nn.Module):
    """One encoder stage: complex convolution -> complex batch norm -> LeakyReLU."""

    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45, padding=(0,0)):
        super().__init__()

        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding

        self.cconv = ComplexConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=filter_size,
            stride=stride_size,
            padding=padding,
        )
        self.cbn = ComplexBatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x):
        """Run the stage on ``x`` (real/imag on the last axis)."""
        # The activation is element-wise, so it acts on the real and imaginary
        # components separately.
        return self.leaky_relu(self.cbn(self.cconv(x)))

### Layer Decoder:

In [None]:
class Decoder(nn.Module):
    """One decoder stage built around a complex transposed convolution.

    Intermediate stages apply complex batch norm and LeakyReLU. The stage
    created with ``last_layer=True`` skips both and squashes the output into a
    bounded mask instead.
    """

    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45,
                 output_padding=(0,0), padding=(0,0), last_layer=False):
        super().__init__()

        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.output_padding = output_padding
        self.padding = padding

        self.last_layer = last_layer

        self.cconvt = ComplexConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=filter_size,
            stride=stride_size,
            output_padding=output_padding,
            padding=padding,
        )
        # Created for every stage (also the last one, where it is not used in forward).
        self.cbn = ComplexBatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x):
        """Upsample ``x``; return activations, or the bounded mask on the last stage."""
        upsampled = self.cconvt(x)

        if self.last_layer:
            # NOTE(review): torch.abs acts element-wise on the stacked real/imag
            # tensor, so "phase" and "magnitude" are computed per component
            # (roughly tanh of each component), not from the complex modulus.
            # Kept as is so existing weights behave the same - confirm intent.
            component_abs = torch.abs(upsampled)
            m_phase = upsampled / (component_abs + 1e-8)
            m_mag = torch.tanh(component_abs)
            return m_phase * m_mag

        return self.leaky_relu(self.cbn(upsampled))

# Funzione di Loss

In [None]:
def wsdr_fn(x_input, y_pred, y_target, eps=1e-8):
    """Time-domain loss: negative cosine similarity between prediction and target.

    Both STFT tensors (real/imag on the last axis) are converted back to
    waveforms with ``torch.istft`` using the module-level ``n_fft``,
    ``hop_length`` and ``window``; the loss is the batch mean of minus the
    normalized correlation between the two waveforms.

    Args:
        x_input: network input. Not used by the computation; kept so the
            signature matches the ``loss_fn(input, prediction, target)`` calls.
        y_pred: predicted STFT, 4-D or 5-D (a 5-D tensor is squeezed on dim 1).
        y_target: target STFT with the same layout as ``y_pred``.
        eps: small constant added to the waveform norms to avoid division by zero.

    Returns:
        A scalar loss tensor. When the shapes are unsupported or the
        computation raises, the problem is printed and a constant tensor of
        1000.0 (``requires_grad=True``) is returned instead of raising, so
        callers can detect and skip the batch.
    """
    def _fallback_loss():
        # Deliberately high value used to flag a failed batch to the caller.
        return torch.tensor(1000.0, requires_grad=True, device=y_pred.device)

    # The prediction must have at least 3 dimensions.
    if y_pred.dim() < 3:
        print(f"ERRORE: y_pred ha dimensioni insufficienti: {y_pred.shape}")
        print("Probabilmente il modello ha fallito durante l'inference")
        return _fallback_loss()

    # Bring both tensors to the [batch, freq, time, 2] layout.
    if y_target.dim() == 5:
        y_target = torch.squeeze(y_target, 1)

    if y_pred.dim() == 5:
        y_pred = torch.squeeze(y_pred, 1)
    elif y_pred.dim() != 4:
        print(f"ERRORE: Dimensioni y_pred non supportate: {y_pred.shape}")
        return _fallback_loss()

    if y_pred.shape != y_target.shape:
        print(f"ERRORE: Shape mismatch - pred: {y_pred.shape}, target: {y_target.shape}")
        return _fallback_loss()

    try:
        y_true_complex = torch.complex(y_target[..., 0], y_target[..., 1])
        y_pred_complex = torch.complex(y_pred[..., 0], y_pred[..., 1])

        # istft requires the window on the same device as its input.
        istft_window = window.to(y_pred_complex.device)

        y_target_time = torch.istft(y_true_complex, n_fft=n_fft, hop_length=hop_length,
                                    window=istft_window, normalized=True)
        y_pred_time = torch.istft(y_pred_complex, n_fft=n_fft, hop_length=hop_length,
                                  window=istft_window, normalized=True)

        # Trim both waveforms to a common length.
        min_len = min(y_target_time.shape[-1], y_pred_time.shape[-1])
        true_flat = y_target_time[..., :min_len].flatten(1)
        pred_flat = y_pred_time[..., :min_len].flatten(1)

        # Unit-normalize each waveform; the caller-supplied eps is honoured here
        # (it used to be ignored in favour of a hard-coded 1e-8).
        true_norm = true_flat / (torch.norm(true_flat, dim=-1, keepdim=True) + eps)
        pred_norm = pred_flat / (torch.norm(pred_flat, dim=-1, keepdim=True) + eps)

        # Minimizing the negative correlation maximizes the correlation.
        correlation = torch.sum(true_norm * pred_norm, dim=-1)
        return torch.mean(-correlation)

    except Exception as e:
        # Best-effort by design: report and hand back the sentinel loss.
        print(f"ERRORE nella loss function: {e}")
        print(f"y_pred shape: {y_pred.shape}")
        print(f"y_target shape: {y_target.shape}")
        return _fallback_loss()

# Metriche

In [None]:
def gather_all_snr_improvements(loader, model, stft_to_waveform, device):
    """Compute the SNR improvement (in dB) of every sample served by ``loader``.

    For each sample the SNR of the noisy input and of the model output are
    measured against the clean reference; the improvement is
    ``snr_after - snr_before``. Values are computed per sample, so a batched
    loader contributes one value per item rather than one per batch.

    Args:
        loader: iterable of ``(noisy, clean)`` tensor batches.
        model: callable invoked as ``model(noisy, is_istft=False)``.
        stft_to_waveform: callable turning a batch into waveforms whose first
            axis is the batch axis.
        device: device the batches are moved to before running the model.

    Returns:
        A list of floats, one per sample; empty if the loader yields nothing.
    """
    def _sample_rows(wave, length):
        # One row per sample, trimmed to the common length.
        trimmed = wave[..., :length]
        if trimmed.dim() < 2:
            trimmed = trimmed.unsqueeze(0)
        return trimmed.reshape(trimmed.shape[0], -1).cpu().numpy()

    snr_improvements = []

    with torch.no_grad():
        for noisy, clean in loader:
            noisy, clean = noisy.to(device), clean.to(device)
            pred = model(noisy, is_istft=False)

            clean_wave = stft_to_waveform(clean)
            noisy_wave = stft_to_waveform(noisy)
            pred_wave = stft_to_waveform(pred)

            # The three waveforms can differ by a few samples; compare the overlap.
            min_len = min(clean_wave.shape[-1], noisy_wave.shape[-1], pred_wave.shape[-1])
            clean_np = _sample_rows(clean_wave, min_len)
            noisy_np = _sample_rows(noisy_wave, min_len)
            pred_np = _sample_rows(pred_wave, min_len)

            # Energies per sample (axis 1), not over the whole batch.
            signal_energy = np.sum(clean_np**2, axis=1)
            noise_before = np.sum((clean_np - noisy_np)**2, axis=1)
            noise_after = np.sum((clean_np - pred_np)**2, axis=1)

            snr_before = 10 * np.log10(signal_energy / (noise_before + 1e-8))
            snr_after = 10 * np.log10(signal_energy / (noise_after + 1e-8))
            snr_improvements.extend((snr_after - snr_before).tolist())

    if not snr_improvements:
        # min()/max() would raise on an empty list; report and return it as is.
        print("Nessun campione valutato: lista SNR vuota")
        return snr_improvements

    print(f"Range: {min(snr_improvements):.2f} - {max(snr_improvements):.2f} dB")

    return snr_improvements

# Number of calls so far; used to number the saved histogram files.
_histogram_counter = 0
def plot_snr_histograms(full_snr_list, output_dir="/kaggle/working"):
    """Save a histogram of SNR improvements as a PNG file.

    Every call increments a module-level counter, which becomes the number in
    the file name ``snr_histogram_epoch_NN.png``.

    Args:
        full_snr_list: SNR improvements in dB. If empty, a message is printed
            and nothing is saved.
        output_dir: folder the PNG is written to; created if missing. Defaults
            to the Kaggle working directory.
    """
    global _histogram_counter
    _histogram_counter += 1

    if len(full_snr_list) == 0:
        print("Lista SNR vuota!")
        return

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    filename = target_dir / f'snr_histogram_epoch_{_histogram_counter:02d}.png'

    mean_improvement = np.mean(full_snr_list)

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.hist(full_snr_list, bins=50, alpha=0.8, label='Test Set', color='blue')
        plt.axvline(mean_improvement, color='red', linestyle='dashed', linewidth=2,
                    label=f'Media: {mean_improvement:.2f} dB')
        plt.title('Distribuzione SNR Improvement - Test Set')
        plt.xlabel('SNR Improvement (dB)')
        plt.ylabel('Frequenza')
        plt.legend()

        # Save and display.
        plt.savefig(filename, dpi=150, bbox_inches='tight')
        plt.show()
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)

    print("Grafico plot_snr_histograms salvato!")



In [None]:
def stft_to_waveform(stft_tensor, n_fft=n_fft, hop_length=hop_length):
    """Convert an STFT stored as a real tensor (real/imag on the last axis) to a waveform.

    Args:
        stft_tensor: 4-D tensor, or 5-D tensor that is squeezed on dim 1.
        n_fft: FFT size used by the inverse transform.
        hop_length: hop between frames used by the inverse transform.

    Returns:
        The waveform produced by ``torch.istft``, or ``None`` (after printing
        the shape) when the tensor is neither 4-D nor 5-D.
    """
    if stft_tensor.dim() == 5:  # [batch, channel, freq, time, 2]
        stft_tensor = torch.squeeze(stft_tensor, 1)
    elif stft_tensor.dim() != 4:  # anything other than [batch, freq, time, 2]
        print(f"Unexpected tensor dimensions: {stft_tensor.shape}")
        return None

    complex_tensor = torch.complex(stft_tensor[..., 0], stft_tensor[..., 1])

    # The synthesis window must have n_fft samples: reuse the module-level
    # window when it fits, otherwise build a matching Hann window. Either way
    # it has to sit on the same device as the spectrum.
    if window.numel() == n_fft:
        synthesis_window = window
    else:
        synthesis_window = torch.hann_window(n_fft)
    synthesis_window = synthesis_window.to(complex_tensor.device)

    waveform = torch.istft(complex_tensor, n_fft=n_fft, hop_length=hop_length,
                           window=synthesis_window, normalized=True)
    return waveform


# Allenamento epoche

In [None]:
def train_epoch(net, train_loader, loss_fn, optimizer, scheduler=None):
    """Run one training epoch and report the mean loss and learning-rate statistics.

    Args:
        net: model, called as ``net(batch, is_istft=False)``.
        train_loader: iterable of ``(input, target)`` batches.
        loss_fn: callable invoked as ``loss_fn(input, prediction, target)``.
        optimizer: optimizer updating ``net``.
        scheduler: per-batch LR scheduler. When omitted, a module-level variable
            named ``scheduler`` is used if present (the previous implicit
            behaviour); if neither exists, no scheduler step is taken.

    Returns:
        ``(avg_loss, lr_stats)``: the mean loss over the batches that completed
        a step, and a dict with the first/last/min/max learning rate seen and
        the number of steps (NaN values if no batch completed).
    """
    # Make the former hidden dependency on a global explicit while keeping
    # existing four-argument callers working.
    lr_scheduler = scheduler if scheduler is not None else globals().get("scheduler")

    net.train()
    train_ep_loss = 0.
    counter = 0
    lr_per_batch = []

    for noisy_input, noisy_target in train_loader:
        noisy_input, noisy_target = noisy_input.to(device), noisy_target.to(device)

        optimizer.zero_grad()

        # Best-effort by design: a failing batch is reported and skipped.
        try:
            pred_x = net(noisy_input, is_istft=False)  # stay in the STFT domain

            loss = loss_fn(noisy_input, pred_x, noisy_target)

            if torch.isnan(loss) or torch.isinf(loss):
                print("NaN/Inf loss detected, skipping batch")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)
            optimizer.step()

            # Record the LR that was used for this step, before it is updated.
            lr_per_batch.append(optimizer.param_groups[0]['lr'])

            if lr_scheduler is not None:
                lr_scheduler.step()

            train_ep_loss += loss.item()
            counter += 1

        except Exception as e:
            print(f"Error in training: {e}")
            print(f"Input shape: {noisy_input.shape}")
            print(f"Target shape: {noisy_target.shape}")
            continue

    # max(..., 1) avoids dividing by zero when every batch was skipped.
    avg_loss = train_ep_loss / max(counter, 1)

    if lr_per_batch:
        lr_first = lr_per_batch[0]
        lr_last = lr_per_batch[-1]
        lr_min = min(lr_per_batch)
        lr_max = max(lr_per_batch)
    else:
        lr_first = lr_last = lr_min = lr_max = float('nan')

    return avg_loss, {
        "lr_first_batch": lr_first,
        "lr_last_batch": lr_last,
        "lr_min": lr_min,
        "lr_max": lr_max,
        "lr_steps": len(lr_per_batch)
    }



# Validazione del modello durante il training

In [None]:
def test_epoch(net, test_loader, loss_fn, use_net=True):
    """Evaluate ``net`` on ``test_loader``: mean loss plus mean SNR improvement.

    Args:
        net: model, called as ``net(batch, is_istft=False)``.
        test_loader: iterable of ``(noisy, clean)`` batches.
        loss_fn: callable invoked as ``loss_fn(input, prediction, target)``.
        use_net: unused; kept for compatibility with existing callers.

    Returns:
        ``(test_loss, mean_snr)``. ``mean_snr`` is ``None`` when the metric
        could not be computed. If no batch was valid the result is
        ``(inf, None)``.
    """
    net.eval()
    test_ep_loss = 0.
    counter = 0

    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for noisy_x, clean_x in test_loader:
            noisy_x, clean_x = noisy_x.to(device), clean_x.to(device)

            # Best-effort by design: a failing batch is reported and skipped.
            try:
                pred_x = net(noisy_x, is_istft=False)  # stay in the STFT domain

                if pred_x.dim() < 3:
                    print(f"ERRORE: Predizione con dimensioni sbagliate: {pred_x.shape}")
                    continue

                loss = loss_fn(noisy_x, pred_x, clean_x)

                # Losses above 100 are treated as anomalous and dropped.
                if torch.isnan(loss) or torch.isinf(loss) or loss.item() > 100:
                    print(f"Loss anomala: {loss.item()}, saltando batch")
                    continue

                test_ep_loss += loss.item()
                counter += 1

            except Exception as e:
                print(f"Errore nel test batch: {e}")
                continue

    if counter == 0:
        print("ATTENZIONE: Nessun batch valido nel test!")
        # None (not a dict) so callers can test "is not None" before comparing.
        return float('inf'), None

    test_ep_loss /= counter

    # Defined up front so the return below is valid even if the metrics fail.
    mean_snr = None
    try:
        # Run the whole test set and collect every SNR improvement.
        snr_full = gather_all_snr_improvements(test_loader, net, stft_to_waveform, device)
        mean_snr = np.mean(snr_full)
        negativi = np.sum(np.array(snr_full) < 0)
        print(f"Media SNR improvement test set: {mean_snr:.2f} dB — peggiorati: {negativi}/{len(snr_full)}")
        plot_snr_histograms(snr_full)

    except Exception as e:
        print(f"Errore nel calcolo metriche: {e}")

    gc.collect()
    torch.cuda.empty_cache()

    return test_ep_loss, mean_snr


# Ciclo di allenamento: stampa della loss di train e di test a ogni epoca

In [None]:
def train(net, train_loader, test_loader, loss_fn, optimizer, scheduler, epochs):
    """Train ``net`` for ``epochs`` epochs, validating and checkpointing after each one.

    After every epoch the model is evaluated with ``test_epoch``; whenever the
    mean SNR improvement beats the best value so far, the weights are saved to
    ``base_dir / "Weights" / "dc20_best.pth"``.

    Args:
        net: model to train.
        train_loader: training batches.
        test_loader: validation batches.
        loss_fn: loss callable forwarded to the epoch functions.
        optimizer: optimizer updating ``net``.
        scheduler: accepted for API compatibility; it is not forwarded by this
            function.
        epochs: number of epochs to run.

    Returns:
        ``(train_losses, test_losses)``, one entry per epoch; two empty lists
        if the initial shape check fails.
    """
    # Sanity-check the model's output shape on one batch before training.
    sample_batch = next(iter(train_loader))
    sample_input = sample_batch[0].to(device)

    if not debug_shapes(net, sample_input):
        print("Training interrotto - Problemi dimensionali!")
        return [], []

    train_losses = []
    test_losses = []

    best_metric = float('-inf')
    best_path = base_dir / "Weights" / "dc20_best.pth"

    for e in tqdm(range(epochs)):
        train_loss, lr_stats = train_epoch(net, train_loader, loss_fn, optimizer)

        # Learning-rate summary for the epoch.
        print(
            f"[LR] Epoch {e+1}: "
            f"first={lr_stats['lr_first_batch']:.3e}, "
            f"last={lr_stats['lr_last_batch']:.3e}, "
            f"min={lr_stats['lr_min']:.3e}, "
            f"max={lr_stats['lr_max']:.3e}, "
            f"steps={lr_stats['lr_steps']}"
        )

        # Validation pass.
        with torch.no_grad():
            test_loss, mean_snr = test_epoch(net, test_loader, loss_fn, use_net=True)

        # The metric may be missing (None) or a non-numeric placeholder when
        # validation produced nothing usable; only compare real numbers so the
        # run is not aborted by a TypeError.
        has_metric = isinstance(mean_snr, (int, float, np.floating))
        if has_metric and mean_snr > best_metric:
            best_metric = mean_snr
            torch.save(net.state_dict(), best_path)
            print(f"Nuovo best: SNR_metric {best_metric:.6f}. Salvato dc20_best.pth")

        train_losses.append(train_loss)
        test_losses.append(test_loss)

        torch.cuda.empty_cache()
        gc.collect()

        print("Loss: {:.6f}...".format(train_loss),
              "Test Loss: {:.6f}".format(test_loss))

    return train_losses, test_losses

# Modello a 20 layer della DCUNet

In [None]:
class DCUnet20(nn.Module):
    """
    20-layer complex U-Net (10 encoder + 10 decoder stages) that predicts a
    mask in the STFT domain and applies it to its input. Tensors keep the real
    and imaginary parts on the last axis.
    """
    def __init__(self, n_fft=2048, hop_length=512):
        super().__init__()

        self.n_fft = n_fft
        self.hop_length = hop_length

        self.set_size(model_complexity=32, input_channels=1, model_depth=20)

        self.model_length = 20 // 2  # 10 encoders and 10 decoders
        depth = self.model_length

        # Encoder stack. Stages are registered under "encoder{i}" and also kept
        # in a plain list for iteration in forward().
        self.encoders = []
        for idx in range(depth):
            stage = Encoder(in_channels=self.enc_channels[idx],
                            out_channels=self.enc_channels[idx + 1],
                            filter_size=self.enc_kernel_sizes[idx],
                            stride_size=self.enc_strides[idx],
                            padding=self.enc_paddings[idx])
            self.add_module(f"encoder{idx}", stage)
            self.encoders.append(stage)

        # Decoder stack, registered under "decoder{i}".
        self.decoders = []
        for idx in range(depth):
            is_last = idx == depth - 1
            if idx == 0 and not is_last:
                # The first decoder only sees the deepest encoder output.
                stage_in = self.dec_channels[idx]
            else:
                # Every other decoder also receives a concatenated skip connection.
                stage_in = self.dec_channels[idx] + self.enc_channels[depth - idx]

            stage = Decoder(in_channels=stage_in,
                            out_channels=self.dec_channels[idx + 1],
                            filter_size=self.dec_kernel_sizes[idx],
                            stride_size=self.dec_strides[idx],
                            padding=self.dec_paddings[idx],
                            output_padding=self.dec_output_padding[idx],
                            last_layer=is_last)
            self.add_module(f"decoder{idx}", stage)
            self.decoders.append(stage)

    def forward(self, x, is_istft=True):
        """Return the masked STFT of ``x``.

        A 4-D input gets a channel axis inserted at dim 1. ``is_istft`` is
        accepted but not used: the result always stays in the STFT domain.
        If a decoder output and its skip connection disagree on dims 2-3, the
        mismatch is printed and zeros shaped like the original input are returned.
        """
        original = x

        if x.dim() == 4:  # [batch, freq, time, 2]
            x = x.unsqueeze(1)  # [batch, 1, freq, time, 2]
        network_input = x

        # Encoder path: remember the input of every stage for the skip connections.
        skips = []
        for encoder in self.encoders:
            skips.append(x)
            x = encoder(x)

        # Decoder path.
        p = x
        last = self.model_length - 1
        for idx, decoder in enumerate(self.decoders):
            p = decoder(p)

            if idx >= last:
                continue  # the final stage has no skip connection

            skip_connection = skips[last - idx]
            if p.shape[2:4] != skip_connection.shape[2:4]:
                print("ERRORE: Dimensioni incompatibili!")
                print(f"Decoder: {p.shape}, Skip: {skip_connection.shape}")
                return torch.zeros_like(original)

            p = torch.cat([p, skip_connection], dim=1)

        # The last decoder output is the mask; apply it element-wise.
        return p * network_input

    def set_size(self, model_complexity, model_depth=20, input_channels=1):
        """Fill in the per-stage channels, kernels, strides and paddings.

        Only ``model_depth == 20`` is supported; anything else raises ValueError.
        """
        if model_depth != 20:
            raise ValueError("Unknown model depth : {}".format(model_depth))

        c1 = model_complexity
        c2 = model_complexity * 2
        c4 = model_complexity * 4

        # 11 entries: stage i maps enc_channels[i] -> enc_channels[i + 1].
        self.enc_channels = [input_channels,
                             c1, c1,
                             c2, c2, c2, c2, c2, c2, c2,
                             c4]

        self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
                                 (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]

        self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2),
                            (2, 1), (2, 2), (2, 1), (2, 2), (2, 1)]

        self.enc_paddings = [(3, 0), (0, 3), (0, 0), (0, 0), (0, 0),
                             (0, 0), (0, 0), (0, 0), (2, 0), (0, 0)]

        # 11 entries: stage i produces dec_channels[i + 1] channels.
        self.dec_channels = [c4,
                             c2, c2, c2, c2, c2, c2, c2,
                             c1, c1,
                             input_channels]

        self.dec_kernel_sizes = [(6, 3), (3, 3), (6, 3), (6, 3), (6, 3),
                                 (6, 4), (8, 5), (7, 5), (1, 7), (7, 1)]

        self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1),
                            (2, 2), (2, 1), (2, 2), (1, 1), (1, 1)]

        self.dec_paddings = [(0, 0), (1, 0), (0, 0), (0, 0), (0, 0),
                             (0, 0), (0, 0), (0, 0), (0, 3), (3, 0)]

        self.dec_output_padding = [(0, 0), (1, 0), (0, 0), (0, 0), (0, 0),
                                   (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]


In [None]:
def debug_shapes(model, sample_input):
    """Check that the model's output has the same shape as its input.

    Runs ``model(sample_input, is_istft=False)`` without gradients. Returns
    True when the shapes match; otherwise prints an error and returns False.
    """
    with torch.no_grad():
        output = model(sample_input, is_istft=False)

    shapes_match = output.shape == sample_input.shape
    if not shapes_match:
        print("ERRORE: Output != Input shape!")
    return shapes_match

# Allenamento della rete

In [None]:
gc.collect()
torch.cuda.empty_cache()

# Single source of truth for the epoch count: the OneCycleLR schedule length
# and the training loop must agree, otherwise the scheduler is stepped past
# its planned total or never completes its cycle.
EPOCHS = 10

dcunet20 = DCUnet20(n_fft, hop_length).to(device)

# Shape sanity check on one batch before spending time on training.
sample_batch = next(iter(train_loader))
sample_input = sample_batch[0].to(device)

if not debug_shapes(dcunet20, sample_input):
    print("FERMA IL TRAINING - Problemi dimensionali!")
    # Raising stops this cell right here; exit() inside a notebook only asks
    # the kernel to shut down and lets the remaining statements run.
    raise SystemExit(1)

loss_fn = wsdr_fn


optimizer = torch.optim.AdamW(
        dcunet20.parameters(),
        lr=5e-5,  # overridden by OneCycleLR below, which starts at max_lr / div_factor
        weight_decay=1e-4,
        betas=(0.9, 0.999),
        eps=1e-8
    )


scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1.5e-4,
        epochs=EPOCHS,
        steps_per_epoch=len(train_loader),
        pct_start=0.15,  # warm-up over the first 15% of the steps
        anneal_strategy='cos',
        div_factor=4.0,   # initial lr = max_lr / div_factor
        final_div_factor=20.0  # final lr = initial lr / final_div_factor
    )

# To resume training from a previously saved checkpoint:
#model_checkpoint = torch.load("/kaggle/input/segments4/pytorch/default/1/dc20_model_4.pth")
#dcunet20.load_state_dict(model_checkpoint)

# Train for EPOCHS epochs, keeping the per-epoch training and validation losses.
train_losses, validation_losses = train(dcunet20, train_loader, test_loader, loss_fn, optimizer, scheduler, EPOCHS)

# Prova dei modelli allenati con valutazioni e grafici di confronto

In [None]:
'''
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path

# Device e parametri globali
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class AudioProcessor:
    """Helper that converts STFT tensors back into time-domain waveforms.

    Stores the STFT settings (FFT size, hop length, sample rate) together with
    a Hann window of length ``n_fft`` moved to the module-level ``device``.
    """

    def __init__(self, n_fft, hop_length, sample_rate):
        self.n_fft, self.hop_length, self.sample_rate = n_fft, hop_length, sample_rate
        self.window = torch.hann_window(n_fft).to(device)

    def stft_to_waveform(self, stft_tensor):
        """Invert a real-valued STFT tensor into a waveform.

        A 5-D input is reduced to its ``[0, 0]`` entry and a 4-D input to its
        ``[0]`` entry, so only the first item is converted. The remaining
        tensor is viewed as complex (its last dimension is the real/imaginary
        pair) and handed to ``torch.istft`` with the stored settings.
        """
        rank = stft_tensor.dim()
        if rank == 5:
            spectrum = stft_tensor[0, 0]
        elif rank == 4:
            spectrum = stft_tensor[0]
        else:
            spectrum = stft_tensor

        return torch.istft(
            torch.view_as_complex(spectrum),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            normalized=True,
            return_complex=False,
        )

    def ensure_waveform(self, tensor):
        """Return ``tensor`` as a waveform, inverting it first if it looks like an STFT.

        A tensor with at least 3 dimensions whose last dimension has size 2 is
        treated as an STFT; anything else only has its leading singleton
        dimensions removed.
        """
        looks_like_stft = tensor.dim() >= 3 and tensor.size(-1) == 2
        if looks_like_stft:
            return self.stft_to_waveform(tensor)

        # Already time-domain: strip leading dimensions of size 1.
        while tensor.dim() > 1 and tensor.size(0) == 1:
            tensor = tensor.squeeze(0)
        return tensor

class SingleSamplePlotter:
    """Save a three-panel waveform figure (clean / noisy / denoised) for one sample."""

    def __init__(self, sample_rate, save_dir='/kaggle/working'):
        self.sample_rate = sample_rate
        self.save_dir = Path(save_dir)
        # parents=True so a nested output folder can be created in one call.
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def plot(self, clean_wave, noisy_wave, predicted_wave, sample_idx, snr_value, snr_improvement):
        """Plot the three waveforms of one sample and save the figure as PNG.

        The signals are truncated to the shortest of the three and drawn
        against a time axis in seconds (sample index / sample_rate). The SNR
        and its improvement appear in the title of the denoised panel.

        Args:
            clean_wave, noisy_wave, predicted_wave: tensors or array-likes.
            sample_idx: integer used in the titles and in the file name.
            snr_value: SNR of the denoised signal, in dB.
            snr_improvement: SNR gain over the noisy input, in dB.

        Returns:
            The saved file path as a string, or None if anything failed
            (the error is printed, not raised, so evaluation can continue).
        """
        def to_numpy(tensor):
            if torch.is_tensor(tensor):
                return tensor.detach().cpu().numpy().squeeze()
            return np.asarray(tensor).squeeze()

        fig = None
        try:
            clean_np, noisy_np, pred_np = map(to_numpy, [clean_wave, noisy_wave, predicted_wave])
            min_len = min(len(clean_np), len(noisy_np), len(pred_np))
            clean_np = clean_np[:min_len]
            noisy_np = noisy_np[:min_len]
            pred_np = pred_np[:min_len]
            time_axis = np.arange(min_len) / self.sample_rate

            fig, axes = plt.subplots(3, 1, figsize=(12, 8))
            colors = ['#2E8B57', '#DC143C', '#1E90FF']
            signals = [('Clean (Ref)', clean_np), ('Noisy (Input)', noisy_np), ('Denoised (Output)', pred_np)]
            for i, (label, signal) in enumerate(signals):
                axes[i].plot(time_axis, signal, color=colors[i], linewidth=0.8, alpha=0.9)
                if i == 2:
                    title = f'{label} - SNR: {snr_value:.2f} dB | ΔSNR: {snr_improvement:+.2f} dB'
                else:
                    title = f'Sample {sample_idx} - {label}'
                axes[i].set_title(title, fontsize=11)
                axes[i].set_ylabel('Amplitude')
                axes[i].grid(True, alpha=0.3)
            axes[-1].set_xlabel('Time (seconds)')
            fig.tight_layout()

            filename = self.save_dir / f'sample_{sample_idx:03d}_plot.png'
            fig.savefig(filename, dpi=120, bbox_inches='tight', facecolor='white')
            return str(filename)
        except Exception as e:
            # Deliberate best-effort: a failed plot must not abort the evaluation.
            print(f" Errore plot campione {sample_idx}: {e}")
            return None
        finally:
            # Close the figure on every path so a failure after subplots()
            # does not leave it open.
            if fig is not None:
                plt.close(fig)


class SNRCalculator:
    """Collection of static helpers for SNR-based metrics."""

    @staticmethod
    def compute_snr(clean, estimate, epsilon=1e-10):
        """Return the global SNR in dB of ``estimate`` against ``clean``.

        The error energy is clamped to at least ``epsilon`` before the
        division. The result is returned as a Python float.
        """
        signal_energy = (clean ** 2).sum()
        error_energy = ((clean - estimate) ** 2).sum().clamp(min=epsilon)
        return (10 * torch.log10(signal_energy / error_energy)).item()

    @staticmethod
    def compute_ssnr(clean, estimate, frame_length=1440, frame_shift=720, C_min=-10, C_max=35):
        """Return the segmental SNR: the mean of per-frame SNRs in dB.

        Both signals are squeezed, converted to NumPy and truncated to the
        shorter length. Each frame SNR is clipped to ``[C_min, C_max]``; a
        frame whose energy ratio is non-positive or non-finite counts as
        ``C_min``. Returns NaN when the signals are shorter than one frame.
        """
        ref = clean.squeeze().cpu().numpy()
        est = estimate.squeeze().cpu().numpy()
        usable = min(len(ref), len(est))
        ref, est = ref[:usable], est[:usable]

        frame_count = (usable - frame_length) // frame_shift + 1
        if frame_count < 1:
            return float("nan")

        per_frame = []
        for frame_idx in range(frame_count):
            start = frame_idx * frame_shift
            ref_seg = ref[start:start + frame_length]
            est_seg = est[start:start + frame_length]
            energy = np.sum(ref_seg ** 2)
            error = np.sum((ref_seg - est_seg) ** 2)
            # The small constant keeps the division defined for a perfect frame.
            ratio = energy / (error + 1e-8)
            if ratio > 0 and np.isfinite(ratio):
                level = 10 * np.log10(ratio)
            else:
                level = C_min
            per_frame.append(np.clip(level, C_min, C_max))
        return np.mean(per_frame)

    @staticmethod
    def align_tensors(*tensors):
        """Truncate every tensor along its last dimension to the shortest one.

        Returns a list with the truncated tensors in the order given.
        """
        shortest = min(item.shape[-1] for item in tensors)
        return [item[..., :shortest] for item in tensors]

def save_audio_sample(clean_wave, noisy_wave, predicted_wave, sample_idx,
                     sample_rate, save_dir='/kaggle/working'):
    """Write the clean, noisy and denoised waveforms of one sample as WAV files.

    Files are named ``sample_<idx>_clean.wav``, ``sample_<idx>_noisy.wav`` and
    ``sample_<idx>_denoised.wav`` (index zero-padded to 3 digits) inside
    ``save_dir``, and are written in that order.

    Returns:
        The list of paths (as strings) written successfully. If a save fails,
        the error is printed and the list holds only the files written before
        the failure.
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(exist_ok=True)

    def as_cpu_channels(wave):
        # 1-D tensors get a leading dimension; other tensors pass through as-is.
        if not torch.is_tensor(wave):
            return torch.tensor(wave).unsqueeze(0)
        wave = wave.detach().cpu()
        return wave.unsqueeze(0) if wave.dim() == 1 else wave

    # Convert all three signals up front, before any file is written.
    tracks = [
        ('clean', as_cpu_channels(clean_wave)),
        ('noisy', as_cpu_channels(noisy_wave)),
        ('denoised', as_cpu_channels(predicted_wave)),
    ]

    files_saved = []
    try:
        for tag, audio in tracks:
            target = out_dir / f'sample_{sample_idx:03d}_{tag}.wav'
            torchaudio.save(str(target), audio, sample_rate)
            files_saved.append(str(target))

        print(f" Audio salvati per campione {sample_idx}:")
        for saved in files_saved:
            print(f"   • {Path(saved).name}")
    except Exception as e:
        print(f" Errore salvataggio audio campione {sample_idx}: {e}")

    return files_saved

def evaluate_testset(model, test_loader, audio_processor, snr_calculator,
                    sample_rate, save_audio_for=None, plot_samples=None, plotter=None):
    """Run the model over the test loader and collect SNR / SSNR statistics.

    For each (noisy, clean) pair the model is called with ``is_istft=True``,
    all three signals are converted to waveforms and truncated to a common
    length, and the SNR and segmental SNR of the prediction are compared with
    those of the noisy input. A sample that raises is reported and skipped.

    Args:
        model: network to evaluate; switched to eval mode here.
        test_loader: iterable yielding (noisy_input, clean_input) pairs.
        audio_processor: object providing ``ensure_waveform``.
        snr_calculator: object providing ``align_tensors``, ``compute_snr``
            and ``compute_ssnr``.
        sample_rate: sample rate passed on when saving audio.
        save_audio_for: optional collection of 1-based sample numbers whose
            audio is saved.
        plot_samples: optional collection of 1-based sample numbers to plot.
        plotter: optional object providing ``plot``; needed for plotting.

    Returns:
        Four lists: SNR values, SNR improvements, SSNR values and SSNR
        improvements, one entry per successfully processed sample.
    """
    model.eval()
    snr_values, snr_improvements = [], []
    ssnr_values, ssnr_improvements = [], []

    print(" Avvio valutazione test set...")

    with torch.no_grad():
        progress = tqdm(test_loader, desc="Processing")
        for sample_no, (noisy_input, clean_input) in enumerate(progress, start=1):
            try:
                noisy_input = noisy_input.to(device, non_blocking=True)
                clean_input = clean_input.to(device, non_blocking=True)

                # Forward pass with the ISTFT flag enabled.
                predicted_output = model(noisy_input, is_istft=True)

                # Bring every signal to a waveform.
                predicted_wave = audio_processor.ensure_waveform(predicted_output)
                clean_wave = audio_processor.ensure_waveform(clean_input)
                noisy_wave = audio_processor.ensure_waveform(noisy_input)

                # Truncate to a common length before comparing.
                clean_aligned, predicted_aligned, noisy_aligned = snr_calculator.align_tensors(
                    clean_wave, predicted_wave, noisy_wave)

                # Global SNR and its gain over the noisy input.
                snr_pred = snr_calculator.compute_snr(clean_aligned, predicted_aligned)
                snr_noisy = snr_calculator.compute_snr(clean_aligned, noisy_aligned)
                snr_improvement = snr_pred - snr_noisy

                # Segmental SNR and its gain over the noisy input.
                ssnr_pred = snr_calculator.compute_ssnr(clean_aligned, predicted_aligned)
                ssnr_noisy = snr_calculator.compute_ssnr(clean_aligned, noisy_aligned)
                ssnr_improvement = ssnr_pred - ssnr_noisy

                # Record only after every metric has been computed.
                snr_values.append(snr_pred)
                snr_improvements.append(snr_improvement)
                ssnr_values.append(ssnr_pred)
                ssnr_improvements.append(ssnr_improvement)

                if save_audio_for and sample_no in save_audio_for:
                    save_audio_sample(clean_aligned, noisy_aligned, predicted_aligned,
                                      sample_no, sample_rate)

                if plotter and plot_samples and sample_no in plot_samples:
                    plot_path = plotter.plot(clean_aligned, noisy_aligned, predicted_aligned,
                                             sample_no, snr_pred, snr_improvement)
                    print(f" Plot salvato per sample {sample_no}: {plot_path}")

            except Exception as e:
                # Best-effort: report the failing sample and move on to the next.
                print(f" Errore campione {sample_no}: {e}")

    return snr_values, snr_improvements, ssnr_values, ssnr_improvements

def plot_results(snr_values, snr_improvements, save_dir='/kaggle/working'):
    """Draw and save histograms of the SNR values and of the SNR improvements.

    The figure has two side-by-side panels, each with a dashed red line at the
    mean; the improvement panel also marks 0 dB. It is written to
    ``<save_dir>/snr_distributions.png``.

    Args:
        snr_values: sequence of SNR values in dB.
        snr_improvements: sequence of SNR improvements in dB.
        save_dir: output folder, created (with parents) if missing.

    Returns:
        The path of the saved PNG as a string.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    try:
        # SNR distribution
        ax1.hist(snr_values, bins=40, alpha=0.7, color='blue', edgecolor='navy')
        mean_snr = np.mean(snr_values)
        # ax1.set_xlim(0, 30)  # uncomment to force a fixed x-axis range
        ax1.axvline(mean_snr, color='red', linestyle='--', linewidth=2,
                    label=f'Media: {mean_snr:.2f} dB')
        ax1.set_title('Distribuzione SNR Predicted vs Clean')
        ax1.set_xlabel('SNR (dB)')
        ax1.set_ylabel('Frequenza')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Improvement distribution
        ax2.hist(snr_improvements, bins=40, alpha=0.7, color='green', edgecolor='darkgreen')
        mean_imp = np.mean(snr_improvements)
        ax2.axvline(mean_imp, color='red', linestyle='--', linewidth=2,
                    label=f'Media: {mean_imp:.2f} dB')
        ax2.axvline(0, color='black', linestyle='-', alpha=0.5, label='No improvement')
        ax2.set_title('Distribuzione SNR Improvements')
        ax2.set_xlabel('SNR Improvement (dB)')
        ax2.set_ylabel('Frequenza')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        # Save the figure; parents=True allows a nested output folder.
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)
        plot_file = save_path / 'snr_distributions.png'
        plt.savefig(plot_file, dpi=150, bbox_inches='tight')
        plt.show()
    finally:
        # Close the figure on every path so it is not left open after the call.
        plt.close(fig)

    return str(plot_file)

def main(save_audio_samples=None, plot_samples=None):
    """Evaluate the pretrained model on the test set and report the results.

    Builds the audio helpers and the model, loads the pretrained weights, runs
    ``evaluate_testset`` on ``test_loader_single_unshuffled``, prints summary
    statistics and saves the distribution histograms.

    Args:
        save_audio_samples: optional list of 1-based sample numbers whose
            audio is saved, e.g. ``[1, 5, 10]``.
        plot_samples: optional list of 1-based sample numbers to plot.

    Returns:
        A tuple ``(snr_values, snr_improvements)``; both are empty lists when
        no sample could be processed.
    """
    # Use the notebook-level hop_length, as the training cell does when it
    # builds DCUnet20(n_fft, hop_length). Passing n_fft as the hop length
    # leaves no overlap between Hann-windowed frames, which the inverse STFT
    # cannot reconstruct.
    audio_processor = AudioProcessor(n_fft, hop_length, SAMPLE_RATE)
    plotter = SingleSamplePlotter(SAMPLE_RATE)
    snr_calculator = SNRCalculator()

    # Load the model
    print(" Caricamento modello...")
    model = DCUnet20(n_fft=n_fft, hop_length=hop_length).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    model.load_state_dict(torch.load("/kaggle/input/pretrainedweights/Pretrained_Weights/Noise2Noise/white.pth",
                                     map_location=device))

    print(f" Setup completato - Device: {device}")

    # Evaluation
    snr_values, snr_improvements, ssnr_values, ssnr_improvements = evaluate_testset(
        model=model,
        test_loader=test_loader_single_unshuffled,
        audio_processor=audio_processor,
        snr_calculator=snr_calculator,
        sample_rate=SAMPLE_RATE,
        save_audio_for=save_audio_samples,
        plot_samples=plot_samples,
        plotter=plotter
    )

    # With no successful sample the statistics below would raise
    # (np.min on an empty list, division by zero), so stop here.
    if not snr_values:
        print(" Nessun campione processato: impossibile calcolare le statistiche.")
        return snr_values, snr_improvements

    # Print the final summary
    print("\n" + "="*60)
    print(" RISULTATI FINALI DEL TEST SET")
    print("="*60)
    print(f"Totale campioni processati: {len(snr_values)}")
    print(f"SNR medio predicted vs clean: {np.mean(snr_values):.2f} ± {np.std(snr_values):.2f} dB")
    print(f"SNR improvement medio: {np.mean(snr_improvements):.2f} ± {np.std(snr_improvements):.2f} dB")
    print(f"SNR minimo: {np.min(snr_values):.2f} dB")
    print(f"SNR massimo: {np.max(snr_values):.2f} dB")
    print(f"Mediana SNR: {np.median(snr_values):.2f} dB")

    print(f"SSNR medio predicted vs clean: {np.mean(ssnr_values):.2f} ± {np.std(ssnr_values):.2f} dB")
    print(f"SSNR improvement medio: {np.mean(ssnr_improvements):.2f} ± {np.std(ssnr_improvements):.2f} dB")

    # Count the improved samples once and reuse the value.
    improved_count = sum(1 for x in snr_improvements if x > 0)
    improvement_rate = (improved_count / len(snr_improvements)) * 100
    print(f"Campioni con miglioramento: {improved_count}/{len(snr_improvements)} ({improvement_rate:.1f}%)")

    # Create the histograms
    plot_file = plot_results(snr_values, snr_improvements)
    print(f"\n Grafico salvato: {plot_file}")

    return snr_values, snr_improvements

# EXECUTION
# To save the audio of specific samples, use:
# results = main(save_audio_samples=[1, 5, 10])  # saves the audio of samples 1, 5 and 10
#
# To skip saving audio:
# results = main()

if __name__ == "__main__":
    # 1-based numbers of the samples whose audio and plot are saved
    SAMPLES_TO_SAVE_AND_PLOT = [1, 3, 10]  # edit with the sample numbers you want
    
    print(" Avvio valutazione...")
    results = main(save_audio_samples=SAMPLES_TO_SAVE_AND_PLOT, plot_samples=SAMPLES_TO_SAVE_AND_PLOT)
    print(" Valutazione completata!")

'''