# Generazione dei dataset con i segmenti migliori delle canzoni

Importazione delle librerie:

In [None]:
from fastai.imports import *
# SET THIS to the dataset that holds the full-length songs you want to segment.
database_dir = Path("/kaggle/input/musdb8k-class/dataset_n1")
# Output root for this notebook; the path is relative, so it resolves against
# the current working directory.
base_dir = Path("results")
# Create the output sub-folders up front (no error if they already exist).
(base_dir / "Weights").mkdir(parents=True, exist_ok=True)
(base_dir / "Samples").mkdir(parents=True, exist_ok=True)

In [2]:
import torch
import torchaudio
import numpy as np
import torch.nn as nn
import matplotlib
# NOTE(review): the `%matplotlib inline` magic that follows this setup selects
# a backend again after 'Agg' has been requested - confirm which one is intended.
matplotlib.use('Agg') # Non-interactive backend
import matplotlib.pyplot as plt
plt.ioff() # Disable interactive mode
%matplotlib inline                             
import IPython.display as ipd
from torch.utils.data import Dataset, DataLoader
import gc
from tqdm import tqdm, tqdm_notebook

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# cuDNN determinism: request deterministic kernels and disable the benchmark
# auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


Percorsi (Path) dei dati di input e di target:

In [3]:
# Training pairs: <database_dir>/train/input and <database_dir>/train/target.
TRAIN_INPUT_DIR = database_dir / 'train' / 'input'
TRAIN_TARGET_DIR = database_dir / 'train'/ 'target'

# Test pairs: <database_dir>/test/input and <database_dir>/test/target.
TEST_NOISY_DIR = database_dir / 'test' / 'input'
TEST_CLEAN_DIR = database_dir / 'test'/ 'target'

Parametri per la trasformazione STFT:

In [4]:
# Sample rate (Hz) that every loaded file is resampled to.
SAMPLE_RATE = 44100
n_fft = 2048 # window size (frequency resolution) - the paper uses 3072, which works well for speech
hop_length = 512 # hop between consecutive windows (time resolution) - the paper uses 768

# Hann window of n_fft samples, created once and placed on the selected device.
window = torch.hann_window(n_fft).to(device)

# Funzioni per filtrare meglio i segmenti

In [5]:
def analyze_noise_correlation(stft_input, stft_target, clean_estimate=None):
    """
    Correlate the noise-magnitude maps of two STFTs.

    Args:
        stft_input: real-view STFT whose last axis holds (real, imaginary).
        stft_target: real-view STFT with the same shape as ``stft_input``.
        clean_estimate: optional reference subtracted from both inputs to
            isolate their noise. When omitted, the element-wise mean of the
            two inputs is used.

    Returns:
        dict with the Pearson ``correlation`` between the flattened noise
        magnitudes, the flag ``is_decorrelated`` (|correlation| < 0.1) and the
        mean squared noise magnitude of each side (``noise_input_power``,
        ``noise_target_power``).

    NOTE(review): with the default mean estimate the two residuals are
    ``(input - target) / 2`` and its negation, so their magnitudes coincide
    and the correlation is ~1 regardless of the audio (NaN when the magnitude
    is constant). Pass a real ``clean_estimate`` for a meaningful value.
    """
    def _magnitude(spec):
        # Euclidean norm of the (real, imaginary) pair stored on the last axis.
        return torch.sqrt(spec[..., 0]**2 + spec[..., 1]**2)

    reference = clean_estimate
    if reference is None:
        # No clean signal supplied: fall back to the average of the two inputs.
        reference = (stft_input + stft_target) / 2

    # Residual ("noise") magnitude of each side with respect to the reference.
    mag_in = _magnitude(stft_input - reference)
    mag_tgt = _magnitude(stft_target - reference)

    # Pearson coefficient between the two flattened magnitude maps.
    stacked = torch.stack([mag_in.flatten(), mag_tgt.flatten()])
    rho = torch.corrcoef(stacked)[0, 1].item()

    return {
        'correlation': rho,
        'is_decorrelated': abs(rho) < 0.1,  # low threshold
        'noise_input_power': torch.mean(mag_in**2).item(),
        'noise_target_power': torch.mean(mag_tgt**2).item()
    }


def _compute_segment_quality(self, fileA, fileB, start_sample):
    """Score one aligned segment pair, printing debug values for the first ten.

    Module-level variant that expects an object exposing ``load_segment``,
    ``n_fft``, ``hop_length`` and ``window`` as ``self``.

    Args:
        fileA: path of the first recording.
        fileB: path of the paired recording.
        start_sample: offset handed to ``self.load_segment`` for both files.

    Returns:
        dict with ``correlation``, ``diversity_score``, ``quality_score``
        (diversity minus |correlation|) and ``energy_balance`` when
        |correlation| <= 0.5 and diversity >= 0.05; otherwise ``None``.
        Any exception is reported on stdout and also yields ``None``.
    """
    try:
        # Load both segments first, then transform them.
        x1 = self.load_segment(fileA, start_sample)
        x2 = self.load_segment(fileB, start_sample)

        stft_options = dict(n_fft=self.n_fft, hop_length=self.hop_length,
                            normalized=True, return_complex=True)
        # The window follows x1's device for both transforms.
        x1_stft = torch.view_as_real(
            torch.stft(x1, window=self.window.to(x1.device), **stft_options))
        x2_stft = torch.view_as_real(
            torch.stft(x2, window=self.window.to(x1.device), **stft_options))

        noise_stats = analyze_noise_correlation(x1_stft, x2_stft)
        spectral_stats = compute_spectral_diversity_score(x1_stft, x2_stft)
        corr = noise_stats['correlation']
        diversity = spectral_stats['diversity_score']

        # Debug aid: count every scored segment and print only the first ten.
        self.debug_count = getattr(self, 'debug_count', 0) + 1
        if self.debug_count <= 10:
            print(f"Segmento {self.debug_count}:")
            print(f"  Correlazione: {corr:.4f}")
            print(f"  Diversità: {diversity:.4f}")

        # Permissive debug thresholds. Written as positive comparisons so a
        # NaN metric fails the check and the segment is rejected.
        passes = abs(corr) <= 0.5 and diversity >= 0.05
        if not passes:
            return None

        return {
            'correlation': corr,
            'diversity_score': diversity,
            'quality_score': diversity - abs(corr),
            'energy_balance': abs(noise_stats['noise_input_power'] - noise_stats['noise_target_power'])
        }

    except Exception as e:
        print(f"Errore nel calcolo qualità: {e}")
        return None


def _select_best_segments(self, max_segments):
    """Keep the ``max_segments`` highest-scoring candidates, with fallbacks.

    When ``self.segment_candidates`` is empty the thresholds stored on
    ``self`` are relaxed and the analysis is re-run on up to
    ``2 * max_segments`` segments; if that still yields nothing,
    ``self._create_fallback_segments`` is used and the method returns.

    Otherwise the candidates are sorted in place by ``quality_score``
    (descending), ``self.segment_list`` is filled with
    ``(fileA, fileB, start)`` tuples and summary statistics are printed.
    """
    if not self.segment_candidates:
        print("⚠️ Nessun segmento soddisfa i criteri rigorosi.")
        print("🔄 Rianalizzando con criteri più permissivi...")

        # Relax every threshold before the second pass.
        self.min_correlation, self.max_correlation, self.min_diversity = 0.0, 1.0, 0.0

        # Second pass over a smaller analysis budget.
        self.segment_candidates = []
        self._analyze_segment_quality(max_segments * 2)

        if not self.segment_candidates:
            # Last resort: take segments without any quality filtering.
            print("🚨 Usando tutti i segmenti disponibili senza filtri di qualità")
            self._create_fallback_segments(max_segments)
            return

    # Regular path: rank in place, then keep the head of the ranking.
    self.segment_candidates.sort(key=lambda cand: cand['quality_score'], reverse=True)
    best = self.segment_candidates[:max_segments]
    self.segment_list = [(cand['fileA'], cand['fileB'], cand['start']) for cand in best]

    # Summary statistics of the kept segments.
    correlations = [cand['correlation'] for cand in best]
    diversities = [cand['diversity_score'] for cand in best]

    print(f"✅ Selezionati {len(self.segment_list)} segmenti")
    print(f"📊 Correlazione media: {np.mean(correlations):.4f} ± {np.std(correlations):.4f}")
    print(f"📊 Diversità media: {np.mean(diversities):.4f} ± {np.std(diversities):.4f}")

def _create_fallback_segments(self, max_segments):
    """Fill ``self.segment_list`` with unfiltered segments as a last resort.

    Walks the paired file lists ``self.noisy_A`` / ``self.noisy_B`` and, for
    each pair, emits consecutive ``(fileA, fileB, start)`` tuples of
    ``self.segment_length`` samples, starting at ``self.skip_start`` and
    stopping 12 seconds (at ``SAMPLE_RATE``) before the end of file A.

    Args:
        max_segments: upper bound on the number of segments to create.
    """
    # Start from a fresh list: without this the function fails when
    # ``segment_list`` was never assigned, and appends to stale entries
    # otherwise. The class method of the same name does the same.
    self.segment_list = []
    total_segments = 0

    for fileA, fileB in zip(self.noisy_A, self.noisy_B):
        if total_segments >= max_segments:
            break

        # File A is loaded only to learn its length after resampling.
        waveform, sr = torchaudio.load(fileA)
        if sr != SAMPLE_RATE:
            waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        total_len = waveform.shape[1]
        start = self.skip_start
        # Leave the last 12 seconds of the recording untouched.
        cutoff = total_len - 12 * SAMPLE_RATE

        while start + self.segment_length <= cutoff and total_segments < max_segments:
            self.segment_list.append((fileA, fileB, start))
            total_segments += 1
            start += self.segment_length

    print(f"📦 Fallback: creati {len(self.segment_list)} segmenti senza filtri")

def compute_spectral_diversity_score(stft_input, stft_target):
    """
    Score how different two real-view STFTs are.

    Three terms are computed and blended with weights 0.4 / 0.3 / 0.3:
      * energy diversity: normalised absolute difference of the squared
        magnitude summed over the last magnitude axis;
      * temporal diversity: absolute difference of the magnitude variance
        along that same axis;
      * phase diversity: mean circular phase distance (radians, in [0, pi]).

    Both inputs must store (real, imaginary) on their last axis. The last
    magnitude axis is presumably time (torch.stft layout) - confirm at the
    call site.

    Returns:
        dict with ``diversity_score`` plus the three individual terms, all as
        Python floats.
    """
    def _polar(spec):
        # Split the trailing (real, imaginary) pair into magnitude and phase.
        real, imag = spec[..., 0], spec[..., 1]
        return torch.sqrt(real**2 + imag**2), torch.atan2(imag, real)

    mag_in, phase_in = _polar(stft_input)
    mag_tgt, phase_tgt = _polar(stft_target)

    # 1. Energy term; the small epsilon guards silent bands against 0/0.
    energy_in = torch.sum(mag_in**2, dim=-1)
    energy_tgt = torch.sum(mag_tgt**2, dim=-1)
    energy_gap = torch.abs(energy_in - energy_tgt)
    energy_diversity = torch.mean(energy_gap / (energy_in + energy_tgt + 1e-8))

    # 2. Temporal term: how differently the magnitudes fluctuate.
    variance_gap = torch.abs(torch.var(mag_in, dim=-1) - torch.var(mag_tgt, dim=-1))
    temporal_diversity = torch.mean(variance_gap)

    # 3. Phase term: take the shorter way around the circle.
    raw_gap = torch.abs(phase_in - phase_tgt)
    circular_gap = torch.minimum(raw_gap, 2*torch.pi - raw_gap)
    phase_diversity = torch.mean(circular_gap)

    # Weighted blend of the three terms.
    combined = 0.4 * energy_diversity + 0.3 * temporal_diversity + 0.3 * phase_diversity

    return {
        'diversity_score': combined.item(),
        'energy_diversity': energy_diversity.item(),
        'temporal_diversity': temporal_diversity.item(),
        'phase_diversity': phase_diversity.item()
    }


# Dichiarazioni del Dataset e del Dataloader

In [6]:

import os
import torch
import torchaudio
import numpy as np
from pathlib import Path

def save_filtered_segments_to_db(dataset, base_output_dir, split_name, sample_rate=44100):
    """
    Write every selected segment pair to disk as WAV files.

    Files go to ``<base_output_dir>/<split_name>/input`` and
    ``<base_output_dir>/<split_name>/target`` under the shared name
    ``segment_NNNNN.wav``; the quality statistics are saved afterwards.

    Args:
        dataset: object exposing ``segment_list`` (``(fileA, fileB, start)``
            tuples) and ``load_segment(file, start)``.
        base_output_dir: root folder of the new database.
        split_name: 'train' or 'test'.
        sample_rate: sampling rate passed to ``torchaudio.save``.
    """
    split_root = Path(base_output_dir) / split_name
    input_dir = split_root / 'input'
    target_dir = split_root / 'target'
    for folder in (input_dir, target_dir):
        folder.mkdir(parents=True, exist_ok=True)

    total = len(dataset.segment_list)
    print(f"💾 Salvando {total} segmenti filtrati in {base_output_dir}/{split_name}/")

    def _as_flat_array(segment):
        # Tensors become NumPy arrays; anything with more than one dimension
        # is squeezed so the data is mono.
        if isinstance(segment, torch.Tensor):
            segment = segment.cpu().numpy()
        if segment.ndim > 1:
            segment = segment.squeeze()
        return segment

    for idx, (fileA, fileB, start_sample) in enumerate(dataset.segment_list):
        # Load both sides through the dataset so they match what it serves.
        raw_A = dataset.load_segment(fileA, start_sample)
        raw_B = dataset.load_segment(fileB, start_sample)
        segment_A = _as_flat_array(raw_A)
        segment_B = _as_flat_array(raw_B)

        # Both files of a pair share the same zero-padded name.
        file_name = f'segment_{idx:05d}.wav'
        # torchaudio expects (channels, samples), hence the added axis.
        torchaudio.save(str(input_dir / file_name), torch.from_numpy(segment_A).unsqueeze(0), sample_rate)
        torchaudio.save(str(target_dir / file_name), torch.from_numpy(segment_B).unsqueeze(0), sample_rate)

        saved = idx + 1
        if saved % 100 == 0:
            print(f"  Salvati {saved}/{total} segmenti...")

    print(f"✅ Completato! Salvati {total} segmenti in {base_output_dir}/{split_name}/")

    # Persist the quality statistics next to the audio.
    save_quality_stats(dataset, base_output_dir, split_name)

def save_quality_stats(dataset, base_output_dir, split_name):
    """Write a text summary of the selected segments' quality metrics.

    The report goes to ``<base_output_dir>/<split_name>_quality_stats.txt``.
    It covers the first ``len(dataset.segment_list)`` entries of
    ``dataset.segment_candidates``.

    Nothing is written when the dataset has no ``segment_candidates``
    attribute, when that list is empty, or when no segment was selected.

    Args:
        dataset: object exposing ``segment_candidates`` (dicts with
            ``correlation`` and ``diversity_score``) and ``segment_list``.
        base_output_dir: folder that receives the report; created if missing.
        split_name: split label used in the file name and the report title.
    """
    candidates = getattr(dataset, 'segment_candidates', None)
    if not candidates:
        return

    selected = candidates[:len(dataset.segment_list)]
    if not selected:
        # min()/max() below would raise on an empty selection.
        return

    correlations = [s['correlation'] for s in selected]
    diversities = [s['diversity_score'] for s in selected]

    stats_file = Path(base_output_dir) / f'{split_name}_quality_stats.txt'
    # Allow calling this function on its own, before any audio was saved.
    stats_file.parent.mkdir(parents=True, exist_ok=True)

    # The report contains non-ASCII characters, so fix the encoding instead
    # of relying on the platform default.
    with open(stats_file, 'w', encoding='utf-8') as f:
        f.write(f"Statistiche di Qualità - {split_name.upper()}\n")
        f.write("="*50 + "\n")
        f.write(f"Numero segmenti: {len(dataset.segment_list)}\n")
        f.write(f"Correlazione media: {np.mean(correlations):.4f} ± {np.std(correlations):.4f}\n")
        f.write(f"Diversità media: {np.mean(diversities):.4f} ± {np.std(diversities):.4f}\n")
        f.write(f"Range correlazione: {min(correlations):.4f} - {max(correlations):.4f}\n")
        f.write(f"Range diversità: {min(diversities):.4f} - {max(diversities):.4f}\n")

    print(f"📊 Statistiche salvate in {stats_file}")


In [7]:
class QualityFilteredNoise2NoiseDataset(Dataset):
    """Noise2Noise dataset made of the best-scoring fixed-length segments.

    The two file lists are sorted and paired positionally. Each pair is cut
    into consecutive segments of ``segment_length`` samples, skipping the
    first 6 seconds and the last 12 seconds of file A. Every segment is
    scored, and the ``max_segments`` best ones are kept in ``segment_list``
    as ``(fileA, fileB, start)`` tuples.

    Indexing returns the real-view STFTs of the two segments of a pair.
    """

    def __init__(self, noisy_file_set_A, noisy_file_set_B, n_fft=1024, hop_length=256,
                 min_correlation_threshold=0.05, max_correlation_threshold=0.15,
                 min_diversity_score=0.2, max_segments=5000):
        """Scan the files, score the candidate segments and select the best.

        Args:
            noisy_file_set_A: paths of the input recordings.
            noisy_file_set_B: paths of the paired target recordings.
            n_fft: STFT size; also the length of the Hann window.
            hop_length: STFT hop size.
            min_correlation_threshold: stored as ``min_correlation``.
            max_correlation_threshold: stored as ``max_correlation``.
            min_diversity_score: stored as ``min_diversity``.
            max_segments: number of segments to keep; up to three times as
                many are analysed.

        NOTE(review): the three thresholds are stored but
        ``_compute_segment_quality`` does not consult them (it accepts every
        finite score), so they currently have no effect on the selection.
        """
        super().__init__()

        self.noisy_A = sorted(noisy_file_set_A)
        self.noisy_B = sorted(noisy_file_set_B)
        if len(self.noisy_A) != len(self.noisy_B):
            # zip() below silently drops the unpaired tail; make it visible.
            print(f"⚠️ Numero di file diverso: A={len(self.noisy_A)}, B={len(self.noisy_B)}; "
                  f"i file in eccesso verranno ignorati.")
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Build the window from this instance's n_fft. A shared module-level
        # window only fits one FFT size and makes torch.stft fail for others.
        self.window = torch.hann_window(n_fft)

        # 165000 samples per segment (about 3.74 s at 44.1 kHz).
        self.segment_length = 165000
        self.skip_start = 6 * SAMPLE_RATE

        # Quality parameters
        self.min_correlation = min_correlation_threshold
        self.max_correlation = max_correlation_threshold
        self.min_diversity = min_diversity_score

        self.segment_list = []

        print("🔍 Analizzando qualità dei segmenti...")
        self.segment_candidates = []
        self._analyze_segment_quality(max_segments * 3)
        self._select_best_segments(max_segments)

    def _load_mono(self, file):
        """Load ``file``, resample it to SAMPLE_RATE and down-mix it to mono."""
        waveform, sr = torchaudio.load(file)
        if sr != SAMPLE_RATE:
            waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

    def _segment_starts(self, file):
        """Return the start offsets of the usable segments of ``file``."""
        total_len = self._load_mono(file).shape[1]
        # Leave the last 12 seconds of the recording untouched.
        cutoff = total_len - 12 * SAMPLE_RATE
        # Every start s satisfies s + segment_length <= cutoff.
        return range(self.skip_start, cutoff - self.segment_length + 1, self.segment_length)

    def _stft(self, waveform):
        """Return the normalised STFT of ``waveform`` as a real-view tensor."""
        spectrum = torch.stft(waveform, n_fft=self.n_fft, hop_length=self.hop_length,
                              window=self.window.to(waveform.device),
                              normalized=True, return_complex=True)
        return torch.view_as_real(spectrum)

    def _analyze_segment_quality(self, max_analysis):
        """Score up to ``max_analysis`` segments and collect the accepted ones.

        Accepted segments are appended to ``self.segment_candidates`` as dicts
        holding ``fileA``, ``fileB``, ``start`` and the quality metrics.
        """
        total_analyzed = 0

        for fileA, fileB in zip(self.noisy_A, self.noisy_B):
            if total_analyzed >= max_analysis:
                break

            for start in self._segment_starts(fileA):
                if total_analyzed >= max_analysis:
                    break

                quality_metrics = self._compute_segment_quality(fileA, fileB, start)
                if quality_metrics is not None:
                    self.segment_candidates.append({
                        'fileA': fileA,
                        'fileB': fileB,
                        'start': start,
                        **quality_metrics
                    })

                total_analyzed += 1
                # Report per segment, so the message really appears every 50.
                if total_analyzed % 50 == 0:
                    print(f"Analizzati {total_analyzed} segmenti...")

    def _compute_segment_quality(self, fileA, fileB, start_sample):
        """Compute the quality metrics of one aligned segment pair.

        Returns:
            dict with ``correlation``, ``diversity_score``, ``quality_score``
            (diversity minus |correlation|) and ``energy_balance``, or
            ``None`` when the segment is rejected or an error occurs (errors
            are printed, not raised).
        """
        try:
            x1_stft = self._stft(self.load_segment(fileA, start_sample))
            x2_stft = self._stft(self.load_segment(fileB, start_sample))

            correlation_metrics = analyze_noise_correlation(x1_stft, x2_stft)
            diversity_metrics = compute_spectral_diversity_score(x1_stft, x2_stft)
            correlation = correlation_metrics['correlation']
            diversity = diversity_metrics['diversity_score']

            # Deliberately permissive so the selection never fails. Only
            # non-finite scores (e.g. a NaN correlation) fail these checks.
            # NOTE(review): self.min_correlation / max_correlation /
            # min_diversity are not used here.
            correlation_ok = abs(correlation) <= 1.0  # accept everything
            diversity_ok = diversity >= 0.0           # accept everything
            if not (correlation_ok and diversity_ok):
                return None

            return {
                'correlation': correlation,
                'diversity_score': diversity,
                'quality_score': diversity - abs(correlation),
                'energy_balance': abs(correlation_metrics['noise_input_power']
                                      - correlation_metrics['noise_target_power'])
            }

        except Exception as e:
            # Best effort: one unreadable segment must not abort the scan.
            print(f"Errore nel calcolo qualità: {e}")
            return None

    def _select_best_segments(self, max_segments):
        """Keep the ``max_segments`` best candidates, with fallbacks.

        With no candidates the stored thresholds are relaxed and the analysis
        is repeated on up to ``2 * max_segments`` segments. If that still
        yields nothing, unfiltered segments are used instead.
        """
        if not self.segment_candidates:
            print("⚠️ Nessun segmento soddisfa i criteri rigorosi.")
            print("🔄 Rianalizzando con criteri più permissivi...")

            # Relax the thresholds before the second pass.
            self.min_correlation = 0.0
            self.max_correlation = 1.0
            self.min_diversity = 0.0

            self.segment_candidates = []
            self._analyze_segment_quality(max_segments * 2)

            if not self.segment_candidates:
                # Last resort: use segments without any quality filter.
                print("🚨 Usando tutti i segmenti disponibili senza filtri di qualità")
                self._create_fallback_segments(max_segments)
                return

        # Sort in place: the head of segment_candidates then matches
        # segment_list.
        self.segment_candidates.sort(key=lambda x: x['quality_score'], reverse=True)
        selected_segments = self.segment_candidates[:max_segments]
        self.segment_list = [(s['fileA'], s['fileB'], s['start']) for s in selected_segments]

        # Statistics
        correlations = [s['correlation'] for s in selected_segments]
        diversities = [s['diversity_score'] for s in selected_segments]

        print(f"✅ Selezionati {len(self.segment_list)} segmenti")
        print(f"📊 Correlazione media: {np.mean(correlations):.4f} ± {np.std(correlations):.4f}")
        print(f"📊 Diversità media: {np.mean(diversities):.4f} ± {np.std(diversities):.4f}")

    def _create_fallback_segments(self, max_segments):
        """Fill ``segment_list`` with up to ``max_segments`` unfiltered segments."""
        self.segment_list = []

        for fileA, fileB in zip(self.noisy_A, self.noisy_B):
            if len(self.segment_list) >= max_segments:
                break

            for start in self._segment_starts(fileA):
                if len(self.segment_list) >= max_segments:
                    break
                self.segment_list.append((fileA, fileB, start))

        print(f"📦 Fallback: creati {len(self.segment_list)} segmenti senza filtri")

    def load_segment(self, file, start_sample):
        """Return the mono segment of ``file`` that begins at ``start_sample``.

        The result has one channel and at most ``segment_length`` samples
        (fewer when the file ends earlier).

        NOTE(review): the whole file is decoded on every call.
        """
        waveform = self._load_mono(file)
        return waveform[:, start_sample:start_sample + self.segment_length]

    def __len__(self):
        """Number of selected segments."""
        return len(self.segment_list)

    def __getitem__(self, index):
        """Return the real-view STFTs of the two segments of pair ``index``."""
        fileA, fileB, start_sample = self.segment_list[index]
        x1 = self.load_segment(fileA, start_sample)
        x2 = self.load_segment(fileB, start_sample)
        return self._stft(x1), self._stft(x2)



# Collect the WAV files of each split (recursively); sorting gives both sides
# of a split a stable listing order.
files_noise_input = sorted(list(TRAIN_INPUT_DIR.rglob("*.wav")))
files_noise_target = sorted(list(TRAIN_TARGET_DIR.rglob("*.wav")))
test_noisy_files = sorted(list(TEST_NOISY_DIR.rglob('*.wav')))
test_clean_files = sorted(list(TEST_CLEAN_DIR.rglob('*.wav')))

print("No. of Training files:",len(files_noise_input))
print("No. of Test files:",len(test_noisy_files))

# Training set built from the best-scoring segments.
# NOTE(review): the class's _compute_segment_quality accepts every segment as
# written, so the thresholds passed here do not seem to affect the selection -
# confirm.
noise2noise_dataset = QualityFilteredNoise2NoiseDataset(
    files_noise_input, files_noise_target, n_fft, hop_length,
    min_correlation_threshold=0.0,   # Accept any low correlation
    max_correlation_threshold=0.3,   # Wider range
    min_diversity_score=0.1,         # Lower threshold
    max_segments=10000
)



# Test set: nothing is normalised
# Test set with quality filters (optional)
test_dataset = QualityFilteredNoise2NoiseDataset(
    test_noisy_files, test_clean_files, n_fft, hop_length,
    min_correlation_threshold=0.0,   # More permissive for the test set
    max_correlation_threshold=1.0,   # Upper bound of 1.0
    min_diversity_score=0.0,         # No diversity floor
    max_segments=1000  # Fewer segments for the test set
)
train_loader = DataLoader(noise2noise_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# For testing purpose
test_loader_single_unshuffled = DataLoader(test_dataset, batch_size=1, shuffle=False)

# After the filtered datasets have been created
print("🔄 Salvando i segmenti filtrati in un nuovo database...")

# Output folder of the new database
new_db_path = "filtered_noise2noise_db_n1_noise"

# Save the training set
save_filtered_segments_to_db(
    dataset=noise2noise_dataset,
    base_output_dir=new_db_path,
    split_name='train',
    sample_rate=SAMPLE_RATE
)

# Save the test set
save_filtered_segments_to_db(
    dataset=test_dataset,
    base_output_dir=new_db_path,
    split_name='test',
    sample_rate=SAMPLE_RATE
)

print("🎉 Database filtrato creato con successo!")


No. of Training files: 65
No. of Test files: 21
🔍 Analizzando qualità dei segmenti...
✅ Selezionati 5631 segmenti
📊 Correlazione media: 1.0000 ± 0.0000
📊 Diversità media: 0.3790 ± 0.1894
🔍 Analizzando qualità dei segmenti...
✅ Selezionati 1000 segmenti
📊 Correlazione media: 1.0000 ± 0.0000
📊 Diversità media: 0.2666 ± 0.1440
🔄 Salvando i segmenti filtrati in un nuovo database...
💾 Salvando 5631 segmenti filtrati in filtered_noise2noise_db_n1_noise/train/
  Salvati 100/5631 segmenti...
  Salvati 200/5631 segmenti...
  Salvati 300/5631 segmenti...
  Salvati 400/5631 segmenti...
  Salvati 500/5631 segmenti...
  Salvati 600/5631 segmenti...
  Salvati 700/5631 segmenti...
  Salvati 800/5631 segmenti...
  Salvati 900/5631 segmenti...
  Salvati 1000/5631 segmenti...
  Salvati 1100/5631 segmenti...
  Salvati 1200/5631 segmenti...
  Salvati 1300/5631 segmenti...
  Salvati 1400/5631 segmenti...
  Salvati 1500/5631 segmenti...
  Salvati 1600/5631 segmenti...
  Salvati 1700/5631 segmenti...
  Salva