# Denoising di un'intera canzone

# Import delle librerie:

In [1]:
import torch
import torchaudio
import numpy as np
import torch.nn as nn
import matplotlib
matplotlib.use('Agg') # backend per salvataggio file, no GUI 
import matplotlib.pyplot as plt
plt.ioff() # Disabilita modalità interattiva     
from torch.utils.data import Dataset, DataLoader
import gc
from tqdm import tqdm
from pathlib import Path 

# Compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Parametri per la trasformazione STFT

In [3]:
SAMPLE_RATE = 44100  # target sample rate in Hz (load_audio resamples to this)

# STFT analysis window size, i.e. the frequency resolution.
# Author's note: the reference paper uses 3072, which works well for speech.
n_fft = 2048
# Step between consecutive STFT frames, i.e. the time resolution.
# Author's note: the reference paper uses 768.
hop_length = 512

# Hann window shared by the forward STFT and the inverse STFT, moved to the compute device.
window = torch.hann_window(window_length=n_fft).to(device=device)

# Definizione dei diversi layer

### Layer convoluzionale per segnali complessi:

In [5]:
class ComplexConv2d(nn.Module):
    """
    2-D convolution on complex-valued feature maps.

    Tensors carry the real and imaginary parts on their last axis (``[..., 2]``).
    Two independent real convolutions, ``real_conv`` (W_r) and ``im_conv`` (W_i),
    each with its own weights and bias, implement the complex product:

        real = W_r(x_r) - W_i(x_i)
        imag = W_i(x_r) + W_r(x_i)

    Because each convolution includes its bias, the real output carries
    ``b_r - b_i`` and the imaginary output carries ``b_i + b_r``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        conv_kwargs = dict(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_size=kernel_size,
                           padding=padding,
                           stride=stride)
        # The attribute names below are the state_dict keys, so they must not change.
        self.real_conv = nn.Conv2d(**conv_kwargs)
        self.im_conv = nn.Conv2d(**conv_kwargs)

        # Glorot (Xavier) normal initialisation of both weight tensors.
        for conv in (self.real_conv, self.im_conv):
            nn.init.xavier_normal_(conv.weight)

    def forward(self, x):
        """Convolve ``x`` (real/imag stacked on the last axis) and return the same layout."""
        x_real, x_im = x[..., 0], x[..., 1]
        real_part = self.real_conv(x_real) - self.im_conv(x_im)
        imag_part = self.im_conv(x_real) + self.real_conv(x_im)
        return torch.stack((real_part, imag_part), dim=-1)

### Layer per deconvoluzione di segnali complessi:

In [6]:
class ComplexConvTranspose2d(nn.Module):
    """
    2-D transposed convolution on complex-valued feature maps.

    Layout and arithmetic mirror ``ComplexConv2d``. ``real_convt`` (W_r) and
    ``im_convt`` (W_i) are independent ``nn.ConvTranspose2d`` layers combined as:

        real = W_r(x_r) - W_i(x_i)
        imag = W_i(x_r) + W_r(x_i)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding=0, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.output_padding = output_padding
        self.padding = padding
        self.stride = stride

        convt_kwargs = dict(in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            output_padding=output_padding,
                            padding=padding,
                            stride=stride)
        # The attribute names below are the state_dict keys, so they must not change.
        self.real_convt = nn.ConvTranspose2d(**convt_kwargs)
        self.im_convt = nn.ConvTranspose2d(**convt_kwargs)

        # Glorot (Xavier) normal initialisation of both weight tensors.
        for convt in (self.real_convt, self.im_convt):
            nn.init.xavier_normal_(convt.weight)

    def forward(self, x):
        """Upsample ``x`` (real/imag stacked on the last axis) and return the same layout."""
        x_real, x_im = x[..., 0], x[..., 1]
        real_part = self.real_convt(x_real) - self.im_convt(x_im)
        imag_part = self.im_convt(x_real) + self.real_convt(x_im)
        return torch.stack((real_part, imag_part), dim=-1)

### Layer per la batch normalization di segnali complessi:

In [7]:
class ComplexBatchNorm2d(nn.Module):
    """
    Batch normalisation for complex feature maps stored as ``[..., 2]``.

    The real and imaginary parts are normalised by two separate
    ``nn.BatchNorm2d`` layers (``real_b`` and ``im_b``). Each layer has its own
    statistics and affine parameters, and no coupling between the two parts is applied.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        bn_kwargs = dict(num_features=num_features, eps=eps, momentum=momentum,
                         affine=affine, track_running_stats=track_running_stats)
        # The attribute names below are the state_dict keys, so they must not change.
        self.real_b = nn.BatchNorm2d(**bn_kwargs)
        self.im_b = nn.BatchNorm2d(**bn_kwargs)

    def forward(self, x):
        """Normalise each part of ``x`` independently and re-stack them on the last axis."""
        parts = (self.real_b(x[..., 0]), self.im_b(x[..., 1]))
        return torch.stack(parts, dim=-1)

### Layer Encoder:

In [8]:
class Encoder(nn.Module):
    """
    One DCUnet encoder stage: complex convolution, then complex batch norm,
    then LeakyReLU.

    The LeakyReLU is applied element-wise, i.e. to the real and imaginary
    parts independently. Submodule names (``cconv``, ``cbn``) are state_dict keys.
    """

    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45, padding=(0,0)):
        super().__init__()

        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding

        self.cconv = ComplexConv2d(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=filter_size,
                                   stride=stride_size,
                                   padding=padding)
        self.cbn = ComplexBatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x):
        """Return ``leaky_relu(cbn(cconv(x)))``."""
        return self.leaky_relu(self.cbn(self.cconv(x)))

### Layer Decoder:

In [9]:
class Decoder(nn.Module):
    """
    One DCUnet decoder stage built on a complex transposed convolution.

    Intermediate stages apply complex batch norm and LeakyReLU. The last stage
    (``last_layer=True``) instead turns the transposed-convolution output ``c``
    into the mask returned to the network:

        output = (c / (|c| + 1e-8)) * tanh(|c|)

    NOTE(review): ``c`` is the stacked ``[..., 2]`` real/imag tensor, so
    ``torch.abs`` is taken per component, not as the complex modulus. The
    result is effectively an element-wise tanh bound on each part. Weights
    trained with this definition depend on it, so switching to the complex
    modulus would require retraining.
    """

    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45,
                 output_padding=(0,0), padding=(0,0), last_layer=False):
        super().__init__()

        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.output_padding = output_padding
        self.padding = padding
        self.last_layer = last_layer

        self.cconvt = ComplexConvTranspose2d(in_channels=in_channels,
                                             out_channels=out_channels,
                                             kernel_size=filter_size,
                                             stride=stride_size,
                                             output_padding=output_padding,
                                             padding=padding)
        # Built even for the last layer, where it is unused: it is part of the
        # state_dict, so dropping it would break loading existing checkpoints.
        self.cbn = ComplexBatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x):
        """Upsample ``x``; return activated features, or the bounded mask on the last layer."""
        conved = self.cconvt(x)
        if self.last_layer:
            magnitude = torch.abs(conved)
            return (conved / (magnitude + 1e-8)) * torch.tanh(magnitude)
        return self.leaky_relu(self.cbn(conved))

In [12]:
def stft_to_waveform(stft_tensor, n_fft=n_fft, hop_length=hop_length):
    """
    Convert a real-view STFT back to a waveform with ``torch.istft``.

    Args:
        stft_tensor: ``[batch, 1, freq, time, 2]`` or ``[batch, freq, time, 2]``,
            with the real and imaginary parts on the last axis.
        n_fft: FFT size used for the forward STFT.
        hop_length: hop size used for the forward STFT.

    Returns:
        The reconstructed waveform, or None (after printing a message) when
        ``stft_tensor`` has any other rank.
    """
    if stft_tensor.dim() == 5:  # [batch, channel, freq, time, 2]
        stft_tensor = torch.squeeze(stft_tensor, 1)
    elif stft_tensor.dim() != 4:  # anything other than [batch, freq, time, 2]
        print(f"Unexpected tensor dimensions: {stft_tensor.shape}")
        return None

    complex_tensor = torch.complex(stft_tensor[..., 0], stft_tensor[..., 1])

    # Build the Hann window for the requested n_fft on the input's device and
    # dtype. The module-level `window` is tied to the global n_fft and device,
    # so istft failed for any other n_fft or for tensors on another device.
    synthesis_window = torch.hann_window(n_fft, device=stft_tensor.device,
                                         dtype=stft_tensor.dtype)

    waveform = torch.istft(complex_tensor, n_fft=n_fft, hop_length=hop_length,
                           window=synthesis_window, normalized=True)
    return waveform


# Modello a 20 layer della DCUNet

In [16]:
class DCUnet20(nn.Module):
    """
    Complex-valued U-Net (DCUnet-20) that predicts a mask on an STFT.

    The input is a real view of a complex STFT, with the real and imaginary
    parts stacked on the last axis. The output is that input multiplied
    element-wise by the mask produced by the last decoder. The output has the
    same shape as the input:

    * ``[batch, freq, time, 2]`` in  ->  ``[batch, freq, time, 2]`` out;
    * ``[batch, 1, freq, time, 2]`` in  ->  ``[batch, 1, freq, time, 2]`` out.

    Layers are registered as ``encoder{i}`` / ``decoder{i}``; those names are
    the state_dict keys of existing checkpoints.
    """

    def __init__(self, n_fft=2048, hop_length=512):
        super().__init__()

        # Stored for reference only; the network itself does not use them.
        self.n_fft = n_fft
        self.hop_length = hop_length

        self.set_size(model_complexity=32, input_channels=1, model_depth=20)

        # 20 layers = 10 encoders + 10 decoders.
        self.model_length = 20 // 2

        # Plain lists are used for iteration only. add_module() is what
        # registers each layer, under the names saved checkpoints use
        # (an nn.ModuleList would rename them to "encoders.{i}").
        self.encoders = []
        for i in range(self.model_length):
            module = Encoder(in_channels=self.enc_channels[i],
                             out_channels=self.enc_channels[i + 1],
                             filter_size=self.enc_kernel_sizes[i],
                             stride_size=self.enc_strides[i],
                             padding=self.enc_paddings[i])
            self.add_module("encoder{}".format(i), module)
            self.encoders.append(module)

        self.decoders = []
        for i in range(self.model_length):
            # Decoder 0 only sees the bottleneck. Every later decoder also
            # receives the skip connection concatenated on the channel axis.
            in_channels = self.dec_channels[i]
            if i > 0:
                in_channels += self.enc_channels[self.model_length - i]

            module = Decoder(in_channels=in_channels,
                             out_channels=self.dec_channels[i + 1],
                             filter_size=self.dec_kernel_sizes[i],
                             stride_size=self.dec_strides[i],
                             padding=self.dec_paddings[i],
                             output_padding=self.dec_output_padding[i],
                             last_layer=(i == self.model_length - 1))
            self.add_module("decoder{}".format(i), module)
            self.decoders.append(module)

    def forward(self, x, is_istft=True):
        """
        Predict the mask and apply it to ``x``.

        Args:
            x: real-view STFT, ``[batch, freq, time, 2]`` or ``[batch, 1, freq, time, 2]``.
            is_istft: kept for backward compatibility; unused (no inverse STFT
                is performed here).

        Returns:
            The masked STFT with the same shape as ``x``. If a spatial size
            mismatch between decoder output and skip connection (or input) is
            detected, an error is printed and ``torch.zeros_like(x)`` is returned.
        """
        orig_x = x

        # The network works on 5-D tensors with an explicit channel axis.
        squeeze_channel = x.dim() == 4
        if squeeze_channel:
            x = x.unsqueeze(1)  # [batch, 1, freq, time, 2]
        net_input = x

        # Encoder path: remember each encoder's input for the skip connections.
        skips = []
        for encoder in self.encoders:
            skips.append(x)
            x = encoder(x)

        # Decoder path: after every decoder but the last, concatenate the
        # matching encoder input on the channel axis.
        p = x
        for i, decoder in enumerate(self.decoders):
            p = decoder(p)
            if i < self.model_length - 1:
                skip_connection = skips[self.model_length - 1 - i]
                if p.shape[2:4] != skip_connection.shape[2:4]:
                    print("ERRORE: Dimensioni incompatibili!")
                    print(f"Decoder: {p.shape}, Skip: {skip_connection.shape}")
                    return torch.zeros_like(orig_x)
                p = torch.cat([p, skip_connection], dim=1)

        mask = p
        # Guard the final multiply too: a mismatch here would otherwise raise
        # or silently broadcast.
        if mask.shape != net_input.shape:
            print("ERRORE: Dimensioni incompatibili!")
            print(f"Mask: {mask.shape}, Input: {net_input.shape}")
            return torch.zeros_like(orig_x)

        output = mask * net_input
        # Return the same rank the caller passed in.
        if squeeze_channel:
            output = output.squeeze(1)
        return output

    def set_size(self, model_complexity, model_depth=20, input_channels=1):
        """
        Define channels, kernels, strides and paddings for every encoder/decoder.

        Only the 20-layer configuration is supported.

        Raises:
            ValueError: if ``model_depth`` is not 20.
        """
        if model_depth != 20:
            raise ValueError("Unknown model depth : {}".format(model_depth))

        mc = model_complexity

        self.enc_channels = [input_channels, mc, mc,
                             mc * 2, mc * 2, mc * 2, mc * 2, mc * 2, mc * 2, mc * 2,
                             mc * 4]
        self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
                                 (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
        self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2),
                            (2, 1), (2, 2), (2, 1), (2, 2), (2, 1)]
        self.enc_paddings = [(3, 0), (0, 3), (0, 0), (0, 0), (0, 0),
                             (0, 0), (0, 0), (0, 0), (2, 0), (0, 0)]

        self.dec_channels = [mc * 4,
                             mc * 2, mc * 2, mc * 2, mc * 2, mc * 2, mc * 2, mc * 2,
                             mc, mc,
                             input_channels]
        self.dec_kernel_sizes = [(6, 3), (3, 3), (6, 3), (6, 3), (6, 3),
                                 (6, 4), (8, 5), (7, 5), (1, 7), (7, 1)]
        self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1),
                            (2, 2), (2, 1), (2, 2), (1, 1), (1, 1)]
        self.dec_paddings = [(0, 0), (1, 0), (0, 0), (0, 0), (0, 0),
                             (0, 0), (0, 0), (0, 0), (0, 3), (3, 0)]
        self.dec_output_padding = [(0, 0), (1, 0), (0, 0), (0, 0), (0, 0),
                                   (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]

In [17]:
def debug_shapes(model, sample_input, verbose=False):
    """
    Check that ``model`` returns a tensor shaped like ``sample_input``.

    The model runs once without gradients. An ``nn.Module`` is temporarily put
    in eval mode and its previous mode is restored afterwards. ``no_grad``
    alone would still let BatchNorm layers update their running statistics.
    If the model adds a size-1 channel axis at dim 1 to the input
    (``[B, F, T, 2]`` -> ``[B, 1, F, T, 2]``), that axis is ignored.

    Args:
        model: callable accepting ``(x, is_istft=False)``.
        sample_input: example input tensor.
        verbose: if True, print the input and output shapes.

    Returns:
        True when the shapes agree; False otherwise (after printing an error).
    """
    is_module = isinstance(model, nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            output = model(sample_input, is_istft=False)
    finally:
        if is_module:
            model.train(was_training)

    if verbose:
        print(f"Input shape: {sample_input.shape}")
        print(f"Output shape: {output.shape}")

    # Ignore an extra single-channel axis added to a 4-D input.
    if output.dim() == sample_input.dim() + 1 and output.shape[1] == 1:
        output = output.squeeze(1)

    if output.shape != sample_input.shape:
        print("ERRORE: Output != Input shape!")
        return False

    if verbose:
        print("Dimensioni corrette")
    return True

# Prova il denoising su un'intera canzone

In [18]:
# Length of each chunk fed to the model, in seconds (164934 samples at 44100 Hz).
# NOTE(review): 3.74 s looks deliberately chosen, presumably so that the STFT
# frame count passes through the encoder/decoder strides without a shape
# mismatch. TODO confirm before changing it.
CHUNK_DURATION = 3.74
# Fraction of each chunk shared with the next one (0.3 = 30% overlap). The
# overlapping region is cross-faded to avoid artifacts at chunk boundaries.
OVERLAP = 0.3

# Path to the pre-trained model weights (edit to match your environment).
MODEL_PATH = "/kaggle/input/white/pytorch/default/1/dc20_white_best.pth"    # MODIFICA CON IL TUO PATH
# Noisy audio file to denoise (edit to match your environment).
INPUT_AUDIO_PATH = "/kaggle/input/musdb18-whitenoiseonly/test/input/INPUT-S10N1-(WHITE_NOISE)-(forkupines-semantics).wav"  # MODIFICA CON IL TUO FILE
# Where the denoised audio is written.
OUTPUT_AUDIO_PATH = "/kaggle/working/canzone_denoised.wav"

# =====================================================
# FUNZIONI HELPER
# =====================================================

def load_audio(file_path, sample_rate=SAMPLE_RATE):
    """
    Load an audio file, resample it to ``sample_rate`` and downmix it to mono.

    Multi-channel audio is averaged over channels, keeping a channel axis of size 1.

    Returns:
        ``(waveform, sample_rate)``; ``sample_rate`` is the requested rate,
        not the file's native rate.
    """
    audio, native_rate = torchaudio.load(file_path)

    if native_rate != sample_rate:
        audio = torchaudio.transforms.Resample(native_rate, sample_rate)(audio)

    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    return audio, sample_rate

def process_chunk(model, chunk_waveform, window):
    """
    Denoise one chunk of audio with ``model`` and return the waveform.

    Args:
        model: network mapping a real-view STFT to a masked STFT; called with
            ``is_istft=False``.
        chunk_waveform: audio samples, either ``[samples]`` or
            ``[1, samples]``. The model expects a single channel, which
            ``load_audio``'s mono downmix provides.
        window: STFT/ISTFT window of length ``n_fft``.

    Returns:
        1-D tensor with exactly as many samples as ``chunk_waveform``.
    """
    num_samples = chunk_waveform.shape[-1]

    # STFT -> real view with real/imag on the last axis, plus a batch axis.
    stft = torch.stft(chunk_waveform, n_fft=n_fft, hop_length=hop_length,
                      window=window, normalized=True, return_complex=True)
    stft = torch.view_as_real(stft).unsqueeze(0)

    with torch.no_grad():
        pred_stft = model(stft, is_istft=False)

    # Drop the batch (and channel, if present) axes -> [freq, time, 2].
    if pred_stft.dim() == 5:
        pred_stft = pred_stft.squeeze(0).squeeze(0)
    elif pred_stft.dim() == 4:
        pred_stft = pred_stft.squeeze(0)

    complex_tensor = torch.complex(pred_stft[..., 0], pred_stft[..., 1])
    # Without `length`, istft returns (frames - 1) * hop_length samples. That is
    # shorter than the input whenever num_samples is not a multiple of
    # hop_length, so the tail of every chunk was lost.
    denoised_chunk = torch.istft(complex_tensor, n_fft=n_fft, hop_length=hop_length,
                                 window=window, normalized=True, length=num_samples)

    return denoised_chunk

def denoise_long_audio(model, audio_path, output_path):
    """
    Denoise a long audio file chunk by chunk and save the result.

    The waveform is split into chunks of ``CHUNK_DURATION`` seconds that
    overlap by a fraction ``OVERLAP`` of the chunk.

    * Each chunk is denoised with ``process_chunk`` and weighted by a Hann window.
    * Chunks are merged by weighted overlap-add: the sum of the weighted chunks
      is divided by the sum of the weights.
    * The result is scaled down if its peak exceeds 0.95, then written to
      ``output_path``; missing parent directories are created.

    Args:
        model: denoising network, already in eval mode.
        audio_path: input audio file.
        output_path: destination file for the denoised audio.

    Returns:
        The denoised waveform, on ``device``, with the same shape as the loaded audio.

    Raises:
        ValueError: if ``OVERLAP`` leaves no hop between chunks (``OVERLAP >= 1``).
    """
    print(f"Caricamento audio da: {audio_path}")
    waveform, sr = load_audio(audio_path)
    waveform = waveform.to(device)

    total_samples = waveform.shape[-1]
    chunk_samples = int(CHUNK_DURATION * sr)
    overlap_samples = int(chunk_samples * OVERLAP)
    hop_samples = chunk_samples - overlap_samples
    if hop_samples <= 0:
        raise ValueError(
            f"OVERLAP={OVERLAP} leaves no hop between chunks of {chunk_samples} samples")

    print(f"Audio totale: {total_samples/sr:.2f} secondi")
    print(f"Chunk size: {chunk_samples} samples ({CHUNK_DURATION}s)")
    print(f"Overlap: {overlap_samples} samples")

    # Accumulators for the weighted overlap-add.
    output_waveform = torch.zeros_like(waveform)
    weight_sum = torch.zeros_like(waveform)

    # Hann weights used to cross-fade overlapping chunks.
    fade_window = torch.hann_window(chunk_samples).to(device)

    # Always process at least one chunk. For audio shorter than the overlap,
    # (total - overlap) // hop + 1 is <= 0, which would yield a silent output.
    num_chunks = max(1, (total_samples - overlap_samples) // hop_samples + 1)

    print(f"Processamento {num_chunks} chunks...")
    for i in tqdm(range(num_chunks)):
        start_idx = i * hop_samples
        end_idx = min(start_idx + chunk_samples, total_samples)

        chunk = waveform[:, start_idx:end_idx]
        # Zero-pad the final, shorter chunk up to the fixed chunk length.
        if chunk.shape[-1] < chunk_samples:
            chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - chunk.shape[-1]))

        denoised_chunk = process_chunk(model, chunk, window)
        if denoised_chunk.dim() == 1:
            denoised_chunk = denoised_chunk.unsqueeze(0)

        # Drop the padded tail, and guard against a shorter ISTFT output.
        actual_length = min(end_idx - start_idx, denoised_chunk.shape[-1])
        denoised_chunk = denoised_chunk[:, :actual_length]

        weights = fade_window[:actual_length]
        region = slice(start_idx, start_idx + actual_length)
        output_waveform[:, region] += denoised_chunk * weights
        weight_sum[:, region] += weights

    # Divide by the accumulated weights: a weighted average where chunks overlap.
    covered = weight_sum > 0
    output_waveform[covered] /= weight_sum[covered]

    # Scale down to a 0.95 peak to avoid clipping. The numel guard is needed
    # because .max() raises on an empty tensor.
    if output_waveform.numel() > 0:
        max_val = torch.abs(output_waveform).max()
        if max_val > 0.95:
            output_waveform = output_waveform * 0.95 / max_val

    print(f"Salvataggio audio denoised in: {output_path}")
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(output_path, output_waveform.cpu(), sr)

    return output_waveform

# =====================================================
# MAIN
# =====================================================

def main():
    """Build DCUnet20, load the checkpoint at MODEL_PATH and denoise INPUT_AUDIO_PATH."""
    banner = "=" * 50
    print(banner)
    print("DCUnet20 Audio Denoising")
    print(banner)

    print("\nCaricamento modello...")
    model = DCUnet20(n_fft=n_fft, hop_length=hop_length).to(device)

    # The checkpoint is passed straight to load_state_dict, so it is expected
    # to be a plain state_dict.
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints, or pass weights_only=True where the PyTorch version supports it.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    print("Modello caricato con successo!")

    print("\nProcessamento audio...")
    denoise_long_audio(model, INPUT_AUDIO_PATH, OUTPUT_AUDIO_PATH)

    print("\n" + banner)
    print("Processo completato!")
    print(f"Audio denoised salvato in: {OUTPUT_AUDIO_PATH}")
    print(banner)

if __name__ == "__main__":
    # Run the denoising pipeline only when executed directly, not on import.
    main()

DCUnet20 Audio Denoising

Caricamento modello...
Modello caricato con successo!

Processamento audio...
Caricamento audio da: /kaggle/input/musdb18-whitenoiseonly/test/input/INPUT-S10N1-(WHITE_NOISE)-(forkupines-semantics).wav
Audio totale: 273.39 secondi
Chunk size: 164934 samples (3.74s)
Overlap: 49480 samples
Processamento 104 chunks...


100%|██████████| 104/104 [00:13<00:00,  7.81it/s]


Salvataggio audio denoised in: /kaggle/working/canzone_denoised.wav

Processo completato!
Audio denoised salvato in: /kaggle/working/canzone_denoised.wav
