# MUSDB18 Dataset Preparation for Source Separation

This notebook:
1. Loads MUSDB18 dataset
2. Converts stereo to mono
3. Normalizes audio
4. Extracts 3-second segments
5. Calculates spectrograms
6. Saves as PyTorch tensors

In [1]:
import musdb
import numpy as np
import torch
import librosa
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
 
# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f718403bd30>

In [2]:
# Configuration parameters
SAMPLE_RATE = 44100
SEGMENT_DURATION = 0.9  # seconds
SEGMENT_SAMPLES = int(SAMPLE_RATE * SEGMENT_DURATION)
N_FFT = 2048
HOP_LENGTH = 512
SAVE_DIR = '../data/processed/'


In [None]:
def load_musdb(subset='train'):
    """Load MUSDB18 dataset."""
    
    return musdb.DB(root="../data/raw/", subsets=[subset])[:2]

def stereo_to_mono(audio):
    """Convert stereo audio to mono by averaging channels."""
    return np.mean(audio, axis=1) if audio.ndim > 1 else audio

def normalize_audio(audio):
    """Normalize audio to [-1, 1] range."""
    return audio / np.max(np.abs(audio))

def extract_segments(waveform, sample_rate=44100, segment_duration=3):
    """
    Extrait des segments de durée fixe d'une forme d'onde.
    
    Arguments:
        waveform (np.array): La forme d'onde à segmenter (1D array)
        sample_rate (int): Fréquence d'échantillonnage en Hz (par défaut: 44100 Hz)
        segment_duration (float): Durée des segments en secondes (par défaut: 3 secondes)
    
    Retourne:
        np.array: Un tableau 2D de taille (n_segments, n_echantillons) contenant les segments
    """
    # Calcul du nombre d'échantillons par segment
    samples_per_segment = int(sample_rate * segment_duration)
    
    # Calcul du nombre total de segments
    total_samples = len(waveform)
    n_segments = (total_samples + samples_per_segment - 1) // samples_per_segment
    
    # Création d'un tableau pour stocker les segments
    segments = np.zeros((n_segments, samples_per_segment))
    
    # Extraction des segments
    for i in range(n_segments):
        start_idx = i * samples_per_segment
        end_idx = min(start_idx + samples_per_segment, total_samples)
        
        # Copie des échantillons dans le segment
        segment_length = end_idx - start_idx
        segments[i, :segment_length] = waveform[start_idx:end_idx]
        
    return segments

def compute_spectrogram(audio):
    """Compute magnitude spectrogram using PyTorch.
    
    Args:
        audio: 1D numpy array or tensor of shape (segment_samples,)
              where segment_samples = SAMPLE_RATE * SEGMENT_DURATION
    
    Returns:
        numpy.ndarray: Magnitude spectrogram of shape (n_fft//2 + 1, n_frames) where n_frames = (segment_samples - n_fft) // hop_length + 1
    """
    # Convert to torch tensor if not already
    if not isinstance(audio, torch.Tensor):
        audio = torch.FloatTensor(audio).to('cuda')
    
    # Compute STFT
    stft = torch.stft(
        audio,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        return_complex=True
    )
    
    # Convert to magnitude spectrogram
    magnitudes = torch.abs(stft)
    
    return magnitudes.cpu().numpy()

In [8]:
def process_dataset(subset='train'):
    """Process the dataset and return spectrograms for mixture and sources."""
    mus = load_musdb(subset)
    all_mix_specs = []  # List to store all segment spectrograms
    all_source_specs = []  # List to store all source spectrograms
    
    print(f"Found {len(mus)} tracks in {subset} set")
    
    # Main progress bar for tracks
    pbar_tracks = tqdm(mus, desc=f"Processing {subset} tracks", unit="track")
    
    for track_idx, track in enumerate(pbar_tracks):
        # Update progress bar description with current track
        pbar_tracks.set_description(f"Processing {track.name}")
        
        # Process mixture
        track.audio = track.audio.T
        mix_audio = stereo_to_mono(track.audio.T)
        mix_audio = normalize_audio(mix_audio)
        mix_segments = extract_segments(mix_audio, SEGMENT_SAMPLES)
        print("mix_segments :", mix_segments.shape)
        
        # Process sources (vocals, drums, bass, other)
        sources_audio = {}
        
        # Process sources first
        for source in tqdm(['vocals', 'drums', 'bass', 'other'], desc="Processing sources", leave=False):
            source_audio = stereo_to_mono(track.targets[source].audio.T)
            source_audio = normalize_audio(source_audio)
            sources_audio[source] = extract_segments(source_audio, SEGMENT_SAMPLES)
        
        # Compute spectrograms for each segment
        for i in tqdm(range(len(mix_segments)), desc="Computing spectrograms", leave=False):
            # Process mixture spectrogram
            mix_spec = compute_spectrogram(mix_segments[i])
            all_mix_specs.append(mix_spec)
            print(len(all_mix_specs))
            
            # Process source spectrograms
            segment_sources = []
            for source in ['vocals', 'drums', 'bass', 'other']:
                source_spec = compute_spectrogram(sources_audio[source][i])
                segment_sources.append(source_spec)
            
            all_source_specs.append(np.stack(segment_sources))

    
    # Convert to PyTorch tensors with proper dimensions
    # Shape: (n_segments, freq_bins, time_frames)
    X = torch.FloatTensor(np.stack(all_mix_specs))
    # Shape: (n_segments, n_sources, freq_bins, time_frames)
    y = torch.FloatTensor(np.stack(all_source_specs))
    
    # Add a channel dimension for the mixture
    X = X.unsqueeze(1)  # Shape: (n_segments, 1, freq_bins, time_frames)
    
    print(f"\nFinal dataset shapes:")
    print(f"X: {X.shape}")
    print(f"y: {y.shape}")
    
    return X, y

In [9]:
# Process training data
print("Processing training data...")
X_train, y_train = process_dataset('train')

"""
# Process test data
print("Processing test data...")
X_test, y_test = process_dataset('test')

# Save processed data
torch.save({
    'X_train': X_train,
    'y_train': y_train,
    'X_test': X_test,
    'y_test': y_test
}, os.path.join(SAVE_DIR, 'processed_musdb18.pt'))

print(f"Saved processed data with shapes:")
print(f"X_train: {X_train.shape}")
print(f"y_train: {y_train.shape}")
print(f"X_test: {X_test.shape}")
print(f"y_test: {y_test.shape}")
"""

Processing training data...
Found 2 tracks in train set


Processing A Classic Education - NightOwl:   0%|          | 0/2 [00:00<?, ?track/s]

mix_segments : (64, 119070)


Processing A Classic Education - NightOwl:   0%|          | 0/2 [00:07<?, ?track/s]


torch.Size([1025, 233])
1
torch.Size([1025, 233])
torch.Size([1025, 233])
torch.Size([1025, 233])
torch.Size([1025, 233])
torch.Size([1025, 233])
2


IndexError: index 1 is out of bounds for axis 0 with size 1