# MUSDB18 Dataset Data Augmentation 
| Name         | Surname    | ID        |
|--------------|------------|-----------|
| ABOUELAZM    | Youssef    | 10960436  |
| BINGLING     | Wu         | 11105141  |
| GARCIA       | Adrian     | 10975956  |
| OUALI        | Ernest     | 10984484  |

This notebook performs data augmentation on the MUSDB18 dataset.

<b> Features of the script:</b>

- Extract tracks from the dataset
- Convert audio to mono
- Apply various audio augmentations (pitch shifting, time stretching, compression, and reverb)
  - May be modified/improved later
- Save augmented data to a given folder

## Configuration and setup

### Import Libraries

In [None]:
# System and file operations
import os
import warnings
warnings.filterwarnings('ignore')

# Numerical and scientific computing
import numpy as np
import random 

# Audio processing and manipulation
import librosa  # Audio analysis library
import librosa.display  # Plotting helpers (waveshow/specshow); import explicitly, not guaranteed by `import librosa`
import soundfile as sf  # Audio file reading/writing
import musdb  # MUSDB18 dataset handler for music source separation

# Data visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Visualization styling
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Jupyter notebook display utilities
from IPython.display import Audio, display  # For playing and displaying audio in notebooks

### Config Class

In [None]:
# Configuration parameters
class Config:
    """
    Notebook-wide settings: dataset/output paths, audio sample rate, and the
    (min, max) ranges from which random augmentation parameters are drawn.

    The dataset root can be overridden with the MUSDB_PATH environment variable
    (the same variable the musdb package reads); otherwise the hard-coded
    default below is used.
    """
    # Paths
    # Update the default to your MUSDB18 dataset, or export MUSDB_PATH
    MUSDB18_PATH = os.environ.get("MUSDB_PATH", "/Users/agathe/Desktop/musdb18")

    OUTPUT_DIR_COHERENT_MIX = "/Users/agathe/Desktop/Augmented_Data_Coherent_Mix"
    # Update this path to your desired output directory for coherent mixes

    OUTPUT_DIR_INCOHERENT_MIX = "/Users/agathe/Desktop/Augmented_Data_Incoherent_Mix"
    # Update this path to your desired output directory for incoherent mixes

    # Audio parameters
    SAMPLE_RATE = 44100  # Standard sample rate for MUSDB18

    # Augmentation ranges (each is a (low, high) pair sampled uniformly)
    PITCH_SHIFT_RANGE = (-2, 2)             # Pitch shift in semitones
    TIME_STRETCH_RANGE = (0.8, 1.2)         # Time stretching factor (>1 = slower/longer, <1 = faster/shorter)
    # NOISE_LEVEL_RANGE = (0.001, 0.01)       # Noise level range not used in this example
    GAIN_RANGE = (0.7, 1.3)                 # Gain adjustment range

    # Dynamic range compression parameters
    COMPRESSION_THRESHOLD_RANGE = (0.3, 0.7)  # Threshold for compression (0.0 to 1.0, on the peak-normalized signal)
    COMPRESSION_RATIO_RANGE = (2.0, 6.0)      # Compression ratio (higher = more compression)

    # Reverb parameters
    REVERB_ROOM_SIZE_RANGE = (0.2, 0.8)       # Room size parameter (0-1); impulse response length in seconds
    REVERB_DAMPING_RANGE = (0.2, 0.8)         # Damping parameter (0-1); higher = faster decay

config = Config()

# Make sure both output folders exist before anything gets written to them
for _out_dir in (config.OUTPUT_DIR_COHERENT_MIX, config.OUTPUT_DIR_INCOHERENT_MIX):
    os.makedirs(_out_dir, exist_ok=True)

print(f"Output directory for coherent data mix generation: {config.OUTPUT_DIR_COHERENT_MIX}")
print(f"Output directory for incoherent data mix generation: {config.OUTPUT_DIR_INCOHERENT_MIX}")

### Dataset Extraction

In [None]:
# Reuse the dataset root from Config so the path is defined in a single place
MUSDB_PATH = config.MUSDB18_PATH

# Extracting the whole set (train + test subsets)
# is_wav=False means that the dataset is read from the compressed .mp4 stem files
# NOTE(review): per the musdb docs, download=True is meant to fetch the 7-second
# preview excerpts of MUSDB18; make sure the full dataset is already present at
# MUSDB_PATH, otherwise the 30 s offsets used later in this notebook will not fit.
mus = musdb.DB(root=MUSDB_PATH, is_wav=False, download=True)

# Official train/validation split of the "train" subset, plus the "test" subset
mus_train = musdb.DB(root=MUSDB_PATH, is_wav=False, download=True, subsets="train", split='train')
mus_valid = musdb.DB(root=MUSDB_PATH, is_wav=False, download=True, subsets="train", split='valid')
mus_test  = musdb.DB(root=MUSDB_PATH, is_wav=False, download=True, subsets="test")

In [None]:
# Report how many tracks each loader exposes
for _label, _db in (("Whole dataset", mus), ("Training set", mus_train),
                    ("Validation set", mus_valid), ("Test set", mus_test)):
    print(f"{_label} loaded with {len(_db)} tracks.")

# Inspect the first track: its name, its Python type, and the type of its mixture audio
_first = mus[0]
print(f"First track: {_first.name}, Type: {type(_first)}")
print(f"Type of elements in the first track: {type(_first.audio)}")


In [None]:
# Print all the different sr in mus.tracks
# Gather the sample rate of every track and check that all of them equal Config.SAMPLE_RATE
sr_list = [track.rate for track in mus.tracks]
unique_rates = np.unique(sr_list)
print(f"\nUnique sample rates in mus.tracks: {unique_rates}")
# Compare every distinct rate (not only the first one) against the configured rate;
# an empty dataset is reported as not matching rather than raising IndexError
rates_match = unique_rates.size > 0 and bool(np.all(unique_rates == config.SAMPLE_RATE))
print(f"Sampling rates are matching: {rates_match}")

### Track object visualization

In [None]:
def extract_track_by_index(mus, k, to_mono=True):
    """
    Return the track at index k of a musdb dataset, optionally converted to mono.

    A MUSDB18 track exposes:
    - track.audio: the full mixture, shape (n_samples, n_channels) when stereo
    - track.sources: dict of individual source stems (vocals, drums, bass, other)
    - track.targets: dict of target stems that musdb defines as combinations of sources

    Important: the Track object stored in `mus.tracks` is modified IN PLACE
    (its mixture, source and target audio attributes are reassigned), so every
    later access to that track through `mus` sees the mono audio as well.

    Args:
        mus (musdb.DB): The musdb.DB loader (full dataset, train, valid or test set).
        k (int): Zero-based index of the track to extract (negative indices are rejected).
        to_mono (bool): If True, downmix every 2-channel array of the track
                        (mixture, sources and targets) to 1D by averaging channels
                        with librosa.to_mono. Default is True.

    Returns:
        musdb.Track: The Track object at index k (the same object held by `mus.tracks`).

    Raises:
        IndexError: If k is outside [0, len(mus.tracks) - 1].
    """
    n_tracks = len(mus.tracks)
    if k < 0 or k >= n_tracks:
        raise IndexError(f"Track index {k} is out of range (0, {n_tracks-1})")

    track = mus.tracks[k]

    if to_mono:
        def _is_stereo(audio):
            # Stereo arrays are (n_samples, 2); already-mono arrays are 1D
            return audio.ndim > 1 and audio.shape[1] == 2

        # Read each `.audio` attribute once: unless it has been assigned, musdb
        # decodes it from disk on every access, so repeated reads are costly.
        mixture = track.audio
        if _is_stereo(mixture):
            # librosa.to_mono expects (n_channels, n_samples), hence the transpose
            track.audio = librosa.to_mono(mixture.T)

        # Raw source stems (vocals, drums, bass, other)
        for source in track.sources.values():
            source_audio = source.audio
            if _is_stereo(source_audio):
                source.audio = librosa.to_mono(source_audio.T)

        # Targets are converted after the sources, as in the original ordering
        for target in track.targets.values():
            target_audio = target.audio
            if _is_stereo(target_audio):
                target.audio = librosa.to_mono(target_audio.T)

    return track

In [None]:
# Load the training track at index 4 (the fifth one); converted to mono by default
track = extract_track_by_index(mus_train, 4)

sr = track.rate
audio_data = track.audio
# Peak-normalize the mixture so its samples lie in [-1, 1]
audio_data = audio_data / np.max(np.abs(audio_data))

print(f"Track name: {track.name}\nSample rate: {sr}")
print(f"audio_data shape: {audio_data.shape}\n"
      f"audio_data duration: {audio_data.shape[0] / sr:.2f} seconds")
print(f"Details of sources: {track.sources}\nDetails of targets: {track.targets}\n")

# Stem names known to the dataset loader
print(f"Sources names: {mus.sources_names}\nTargets names: {mus.targets_names}")

In [None]:
def plot_stems(track, sr=None, max_duration=None):
    """
    Plot the waveform of every target stem of a MUSDB18 track and embed an audio player for each.

    Every stem is downmixed with librosa.to_mono and cropped to `max_duration`
    seconds before plotting and playback. All entries of `track.targets` are
    shown (including mixture-like targets such as 'accompaniment' or
    'linear_mixture' if the dataset defines them).

    Parameters:
        track: musdb.Track
        sr: sample rate used for plotting/playback; if None, uses track.rate
        max_duration: max seconds to display/play per stem;
                      if None, uses the whole track duration (len(track.audio) / track.rate)
    """
    # Resolve defaults per call; defaults in the signature would be frozen to
    # whichever global `track` existed when this function was defined.
    if sr is None:
        sr = track.rate
    if max_duration is None:
        max_duration = track.audio.shape[0] / track.rate

    stem_names = list(track.targets.keys())
    num_stems = len(stem_names)
    n_samples = int(max_duration * sr)

    def _load_stem(name):
        # Mono, cropped copy of one target stem
        audio = librosa.to_mono(track.targets[name].audio.T)
        return audio[:n_samples]

    # squeeze=False keeps `axes` indexable even when there is a single stem
    fig, axes = plt.subplots(num_stems, 1, figsize=(14, 2 * num_stems), sharex=True, squeeze=False)
    axes = axes[:, 0]

    print(f"Showing up to {max_duration} seconds of {track.name} for: {', '.join(stem_names)}")
    print()

    for ax, stem_name in zip(axes, stem_names):
        librosa.display.waveshow(_load_stem(stem_name), sr=sr, ax=ax)
        ax.set_title(stem_name.capitalize())
        ax.set_ylabel("Amplitude")
        ax.grid(True)

    axes[-1].set_xlabel("Time (s)")
    # Figure-level title: plt.title() would overwrite the last stem's subplot title
    fig.suptitle(track.name)
    plt.tight_layout()
    plt.show()

    # Audio players, one per target stem (same stems as plotted above)
    for stem_name in stem_names:
        print(f"▶️ {stem_name.capitalize()}")
        display(Audio(_load_stem(stem_name), rate=sr))

In [None]:
# Waveforms and audio players for every target stem of the track loaded above
plot_stems(track)

### Audio Extraction function

In [None]:
def extract_audio(track, stem='mixture', duration=10.0, offset=30.0):
    """
    Extract a mono excerpt of one stem from a MUSDB18 track at a fixed offset.

    Args:
        track (musdb.Track): A MUSDB18 track object (mono or stereo). Only
            `track.rate`, `track.audio` and `track.sources` are used.
        stem (str): Which stem to extract:
            - 'mixture': the full mixture audio
            - any key of `track.sources` (in MUSDB18: 'vocals', 'drums', 'bass', 'other')
        duration (float): Length of the excerpt in seconds (must be >= 0).
        offset (float): Start of the excerpt in seconds (must be >= 0). Default 30.0.

    Returns:
        np.ndarray: 1D array of int(duration * sr) samples starting at sample
        int(offset * sr). Multi-channel audio is downmixed with librosa.to_mono.

    Raises:
        ValueError: If `duration` or `offset` is negative, if `stem` is not
            available in the track, or if offset + duration exceeds the track length.
    """
    # Reject negative values before touching (and possibly decoding) any audio;
    # a negative offset would otherwise slice from the END of the array.
    if duration < 0:
        raise ValueError(f"Duration must be non-negative, got {duration} seconds.")
    if offset < 0:
        raise ValueError(f"Offset must be non-negative, got {offset} seconds.")

    sr = track.rate

    # Select the requested stem
    if stem == 'mixture':
        audio = track.audio
    elif stem in track.sources:
        audio = track.sources[stem].audio
    else:
        valid_stems = ['mixture'] + list(track.sources.keys())
        raise ValueError(f"Invalid stem '{stem}'. Valid options are: {', '.join(valid_stems)}")

    # Downmix (n_samples, n_channels) audio to 1D; librosa expects channels first
    if audio.ndim > 1:
        audio = librosa.to_mono(audio.T)

    # Latest start time that still leaves room for the full duration
    total_duration = len(audio) / sr
    max_offset = total_duration - duration
    if offset > max_offset:
        raise ValueError(f"Offset {offset} seconds exceeds maximum possible offset of {max_offset:.2f} seconds for this track.")

    start = int(offset * sr)
    end = start + int(duration * sr)
    return audio[start:end]

In [None]:
def plot_audio_with_melspec(audio, sr=44100, n_fft=2048, hop_length=512, n_mels=128, title="Audio Waveform and Mel Spectrogram"):
    """
    Plot a mono signal's waveform above its mel spectrogram (in dB relative to its peak).

    Parameters:
        audio (np.ndarray): 1D (mono) audio signal
        sr (int): Sample rate of the audio
        n_fft (int): FFT window size for the mel spectrogram
        hop_length (int): Hop length (in samples) for the mel spectrogram
        n_mels (int): Number of mel bands
        title (str): Figure-level title

    Raises:
        ValueError: If `audio` has more than one dimension.
    """
    if audio.ndim > 1:
        raise ValueError("Audio must be a 1D array (mono) for waveform plotting. Make sure you used extract_audio().")

    # Two stacked panels sharing the time axis
    fig, (wave_ax, spec_ax) = plt.subplots(2, 1, figsize=(14, 8), sharex=True)
    fig.suptitle(title, fontsize=16)

    # Top panel: time-domain waveform
    librosa.display.waveshow(audio, sr=sr, ax=wave_ax)
    wave_ax.set_title("Waveform")
    wave_ax.set_ylabel("Amplitude")
    wave_ax.grid(True)

    # Bottom panel: mel power spectrogram converted to dB w.r.t. its maximum
    spec_db = librosa.power_to_db(
        librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels),
        ref=np.max,
    )
    mesh = librosa.display.specshow(spec_db, x_axis='time', y_axis='mel', sr=sr,
                                    hop_length=hop_length, ax=spec_ax)
    spec_ax.set_title("Mel Spectrogram")
    spec_ax.set_xlabel("Time (s)")
    spec_ax.set_ylabel("Frequency (mel)")
    fig.colorbar(mesh, ax=spec_ax, format='%+2.0f dB')

    plt.tight_layout()
    plt.subplots_adjust(top=0.92)  # leave room for the suptitle
    plt.show()

In [None]:
# Load the training track at index 8 (the ninth one); extract_track_by_index converts it to mono in place
my_track = extract_track_by_index(mus_train, 8) 

# 10-second excerpt of the vocals stem starting 30 s into the track
my_audio = extract_audio(my_track, stem='vocals', duration=10.0, offset=30.0)
# NOTE(review): yields NaNs (0/0) if this vocal excerpt happens to be silent
my_audio = my_audio / np.max(np.abs(my_audio))  # Normalize audio to [-1, 1]

print(f"Extracted audio shape: {my_audio.shape}")
print(f"Extracted audio duration: {my_audio.shape[0] / my_track.rate:.2f} seconds")
# Play the extracted audio
display(Audio(my_audio, rate=my_track.rate))

# Waveform + mel spectrogram of the excerpt
plot_audio_with_melspec(my_audio, sr=my_track.rate)

## Data Augmentation Functions

Implementations of several audio augmentation techniques, similar to those available in Scaper.

### Pitch Shift

In [None]:
def pitch_shift_audio(audio, sr, n_steps=0, normalize=True):
    """
    Shift the pitch of a mono audio array by `n_steps` semitones using librosa.

    The input is ALWAYS peak-normalized to [-1, 1] before shifting; `normalize`
    only controls whether the shifted output is peak-normalized again. An
    all-zero (silent) signal is never divided by its zero peak, so it cannot
    turn into NaNs.

    Args:
        audio (np.ndarray): Mono (1D) audio data.
        sr (int): Sample rate of the audio data.
        n_steps (float): Number of semitones to shift (negative = lower).
        normalize (bool): If True, peak-normalize the output to [-1, 1]. Default True.

    Returns:
        np.ndarray: Pitch-shifted audio array (a new array).

    Raises:
        ValueError: If audio is not a 1D numpy array (mono).
    """
    if audio.ndim > 1:
        raise ValueError("Expected mono audio array (1D), but got multi-dimensional audio.")

    # Pre-normalize; skip the division for silence (0/0 -> NaN, hidden by the global warnings filter)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak

    shifted_audio = librosa.effects.pitch_shift(audio, sr=sr, n_steps=n_steps)

    if normalize:
        out_peak = np.max(np.abs(shifted_audio))
        if out_peak > 0:
            shifted_audio = shifted_audio / out_peak
    return shifted_audio

In [None]:
# Demo: shift the vocal excerpt up by 6 semitones
# (well outside Config.PITCH_SHIFT_RANGE = (-2, 2) used for the actual augmentation)
n_steps=6
ps_my_audio = pitch_shift_audio(my_audio, sr=my_track.rate, n_steps=n_steps)
print(f"Pitch-shifted audio shape: {ps_my_audio.shape}")

# Play the pitch-shifted audio
display(Audio(ps_my_audio, rate=my_track.rate))
plot_audio_with_melspec(ps_my_audio, sr=my_track.rate, title=f"Pitch-Shifted Audio Waveform, n_steps={n_steps}")

In [None]:
def plot_audio_comparison_with_melspec(original, transformed, sample_rate=44100, title="Audio Comparison",
                                     n_fft=2048, hop_length=512, n_mels=128, figsize=(12, 10)):
    """
    Plot waveforms and mel spectrograms of an original and a transformed signal side by side.

    Layout is a 2x2 grid. The top row shows waveforms and the bottom row shows
    mel spectrograms in dB (each relative to its own maximum). The left column
    is `original` and the right column is `transformed`. All four panels use the
    x-range [0, duration of `transformed`], so e.g. a time-stretched signal is
    compared against the same time window of the original.

    Args:
        original (np.array): Original audio array (mono, 1D)
        transformed (np.array): Transformed audio array (mono, 1D)
        sample_rate (int): Sample rate of both signals
        title (str): Main title for the figure
        n_fft (int): FFT window size for the mel spectrograms
        hop_length (int): Hop length for the mel spectrograms
        n_mels (int): Number of mel bands to generate
        figsize (tuple): Figure size (width, height)

    Raises:
        ValueError: If either input is not a 1D array.
    """
    # Validate first so that a bad call does not leave an empty figure behind
    if original.ndim > 1:
        raise ValueError("Expected mono audio for 'original', but got multi-dimensional array")
    if transformed.ndim > 1:
        raise ValueError("Expected mono audio for 'transformed', but got multi-dimensional array")

    # Shared x-axis limit for every panel
    trans_duration = len(transformed) / sample_rate

    fig, axs = plt.subplots(2, 2, figsize=figsize)
    fig.suptitle(title, fontsize=16)

    panels = ((0, original, "Original"), (1, transformed, "Transformed"))

    # Top row: waveforms
    for col, signal, label in panels:
        ax = axs[0, col]
        ax.plot(np.arange(len(signal)) / sample_rate, signal)
        ax.set_title(f"{label} Waveform")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Amplitude")
        ax.grid(True)
        ax.set_xlim(0, trans_duration)

    # Bottom row: mel spectrograms in dB
    for col, signal, label in panels:
        ax = axs[1, col]
        mel = librosa.feature.melspectrogram(
            y=signal,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels
        )
        mel_db = librosa.power_to_db(mel, ref=np.max)
        img = librosa.display.specshow(
            mel_db,
            x_axis='time',
            y_axis='mel',
            sr=sample_rate,
            hop_length=hop_length,
            ax=ax
        )
        ax.set_title(f'{label} Mel Spectrogram')
        ax.set_xlim(0, trans_duration)
        fig.colorbar(img, ax=ax, format='%+2.0f dB')

    plt.tight_layout()
    plt.subplots_adjust(top=0.92)  # Adjust to make room for the main title
    plt.show()

In [None]:
# Side-by-side waveform / mel-spectrogram comparison: original vs pitch-shifted excerpt
plot_audio_comparison_with_melspec(my_audio, ps_my_audio, sample_rate=my_track.rate, title=f"Original vs Pitch-Shifted Audio n_steps={n_steps}")

### Time Stretch

In [None]:
def time_stretch_audio(audio, stretch_factor=1.0, normalize=True):
    """
    Change the duration of a mono audio array without changing its pitch.

    The input is ALWAYS peak-normalized to [-1, 1] first; `normalize` only
    controls whether the stretched output is peak-normalized again. Silent
    input is never divided by its zero peak (which would produce NaNs).

    Args:
        audio (np.ndarray): Mono (1D) audio data.
        stretch_factor (float): Time stretch factor (must be > 0):
                               - stretch_factor > 1.0: slower/longer audio
                               - stretch_factor < 1.0: faster/shorter audio
                               - stretch_factor = 1.0: unchanged (default)
        normalize (bool): If True, peak-normalize the output to [-1, 1].

    Returns:
        np.ndarray: Time-stretched audio array.

    Raises:
        ValueError: If audio is not 1D, or if stretch_factor is not positive.
    """
    if audio.ndim > 1:
        raise ValueError("Expected mono audio array (1D), but got multi-dimensional audio. ")
    if stretch_factor <= 0:
        # 0 would divide by zero below; negative rates are meaningless
        raise ValueError(f"stretch_factor must be positive, got {stretch_factor}.")

    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak

    # librosa's `rate` is the playback speed (rate > 1 = faster/shorter), i.e.
    # the inverse of our stretch_factor.
    ts_audio = librosa.effects.time_stretch(audio, rate=1 / stretch_factor)

    if normalize:
        out_peak = np.max(np.abs(ts_audio))
        if out_peak > 0:
            ts_audio = ts_audio / out_peak
    return ts_audio

In [None]:
# Demo: make the excerpt 1.4x longer (librosa rate = 1/1.4);
# note this is outside Config.TIME_STRETCH_RANGE = (0.8, 1.2) used for augmentation
stretch_factor = 1.4
ts_my_audio = time_stretch_audio(my_audio, stretch_factor=stretch_factor, normalize=True)

print(f"Time-stretched audio shape: {ts_my_audio.shape}")
# Play the time-stretched audio
display(Audio(ts_my_audio, rate=my_track.rate))
# The comparison plot crops both columns to the stretched signal's duration
plot_audio_comparison_with_melspec(my_audio, ts_my_audio, sample_rate=my_track.rate, 
                                     title=f"Original vs Time-Stretched Audio (Factor: {stretch_factor})")

### Add Noise (Not considered)

In [None]:
# def add_noise_audio(audio, noise_level=0.001, normalize=True):
#     """
#     Add Gaussian noise to audio data in a numpy array.
    
#     Args:
#         audio (np.ndarray): Audio data as a mono numpy array
#         sr (int): Sample rate of the audio data
#         noise_level (float): Standard deviation of noise to add
#         normalize (bool): If True, normalizes the audio to the range [-1, 1] after adding noise.
    
#     Returns:
#         np.ndarray: Noise-added audio array
                         
#     Raises:
#         ValueError: If audio is not a 1D numpy array (mono)
#     """
#     # Check if audio is mono (1D array)
#     if audio.ndim > 1:
#         raise ValueError("Expected mono audio array (1D), but got multi-dimensional audio.")
    
#     audio = audio / np.max(np.abs(audio))  # Normalize audio to [-1, 1] before processing
#     # Generate noise matching the audio shape
#     noise = np.random.normal(0, noise_level, audio.shape)
    
#     # Add noise to audio
#     noise_audio = audio + noise
#
#     # Normalize the audio to [-1, 1] if requested
#     if normalize:
#         noise_audio = noise_audio / np.max(np.abs(noise_audio))
#
#     return noise_audio

In [None]:
# noise_level = 0.005  # Standard deviation of noise

# noise_my_audio = add_noise_audio(my_audio, noise_level=noise_level, normalize=True)
# display(Audio(noise_my_audio, rate=my_track.rate))

# plot_audio_comparison_with_melspec(my_audio, noise_my_audio, sample_rate=my_track.rate,
#                                      title=f"Original vs Noise-Added Audio (Noise Level: {noise_level})")

### Dynamic Range Compression

In [None]:
def compress_dynamic_range_audio(audio, threshold=0.2, ratio=2.0, normalize=True):
    """
    Apply an instantaneous, per-sample dynamic range compressor to mono audio.

    The input is first peak-normalized to [-1, 1]. Then every sample whose
    magnitude exceeds `threshold` is rescaled so that its magnitude becomes
    threshold + (|x| - threshold) / ratio, keeping its sign. There is no
    attack/release smoothing. Silent input is never divided by its zero peak,
    so it cannot become NaN.

    Args:
        audio (np.ndarray): Mono (1D) audio data; it is not modified.
        threshold (float): Compression threshold on the peak-normalized signal (0.0 to 1.0).
        ratio (float): Compression ratio (>1.0, higher = more compression):
                      the excess above the threshold is divided by this value.
        normalize (bool): If True, peak-normalize the compressed output to [-1, 1].

    Returns:
        np.ndarray: Compressed audio array (a new array).

    Raises:
        ValueError: If audio is not a 1D numpy array (mono).
    """
    if audio.ndim > 1:
        raise ValueError("Expected mono audio array (1D), but got multi-dimensional audio.")

    # Peak-normalize into a NEW array so the caller's `audio` is never modified
    peak = np.max(np.abs(audio))
    compressed_audio = audio / peak if peak > 0 else audio.copy()

    magnitude = np.abs(compressed_audio)
    above_threshold = magnitude > threshold

    if np.any(above_threshold):
        over = magnitude[above_threshold]
        # Target magnitude = threshold + (input - threshold) / ratio; applying it
        # as a multiplicative gain preserves each sample's sign.
        compressed_audio[above_threshold] *= (threshold + (over - threshold) / ratio) / over

    if normalize:
        out_peak = np.max(np.abs(compressed_audio))
        if out_peak > 0:
            compressed_audio = compressed_audio / out_peak

    return compressed_audio

In [None]:
# Demo settings; threshold applies to the peak-normalized signal.
# Note: ratio 12 is far above Config.COMPRESSION_RATIO_RANGE = (2.0, 6.0), to make the effect audible.
threshold = 0.6  # Compression threshold
ratio = 12.0  # Compression ratio

com_my_audio = compress_dynamic_range_audio(my_audio, threshold=threshold, ratio=ratio, normalize=True)
print(f"Compressed audio shape: {com_my_audio.shape}")
# Play the compressed audio
print(f'Original audio:')
display(Audio(my_audio, rate=my_track.rate))
print(f'Compressed audio:')
display(Audio(com_my_audio, rate=my_track.rate))
plot_audio_comparison_with_melspec(my_audio, com_my_audio, sample_rate=my_track.rate,
                                     title=f"Original vs Compressed Audio (Threshold: {threshold}, Ratio: {ratio})")

### Reverb

In [None]:
def apply_reverb_audio(audio, sr, room_size=0.5, damping=0.5, normalize=True):
    """
    Add a synthetic reverb by convolving mono audio with a random, exponentially decaying impulse response.

    The impulse response (IR) is int(sr * room_size) samples of Gaussian noise
    (std 0.1, drawn from NumPy's global random generator, so results vary
    between calls unless it is seeded) multiplied by
    exp(-10 * damping * t / IR_length). The convolution is truncated to the
    input length and mixed 30 % wet / 70 % dry. The input is peak-normalized
    first; silent input is never divided by its zero peak.

    Args:
        audio (np.ndarray): Mono (1D) audio data.
        sr (int): Sample rate of the audio data.
        room_size (float): Room size parameter (0-1); IR length in seconds.
                          Higher values create longer reverb tails.
        damping (float): Damping parameter (0-1).
                        Higher values create faster decay of the reverb.
        normalize (bool): If True, peak-normalize the output to [-1, 1].

    Returns:
        np.ndarray: Reverb-added audio array with the same length as `audio`.

    Raises:
        ValueError: If audio is not a 1D numpy array (mono).
    """
    if audio.ndim > 1:
        raise ValueError("Expected mono audio array (1D), but got multi-dimensional audio.")

    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak

    ir_length = int(sr * room_size)

    if ir_length > 0:
        decay = np.exp(-np.arange(ir_length) * damping * 10 / ir_length)
        impulse_response = np.random.normal(0, 0.1, ir_length) * decay

        # Full linear convolution via zero-padded FFTs: O((N+M) log(N+M)) instead
        # of np.convolve's O(N*M), which is very slow for second-long IRs at 44.1 kHz.
        n_full = len(audio) + ir_length - 1
        n_fft = 1 << (n_full - 1).bit_length()  # next power of two >= n_full
        spectrum = np.fft.rfft(audio, n_fft) * np.fft.rfft(impulse_response, n_fft)
        reverb_audio = np.fft.irfft(spectrum, n_fft)[:len(audio)]
    else:
        # room_size too small to give a non-empty IR: no wet signal
        reverb_audio = np.zeros(len(audio))

    # Mix with original audio (30% reverb, 70% original)
    res = 0.7 * audio + 0.3 * reverb_audio

    if normalize:
        out_peak = np.max(np.abs(res))
        if out_peak > 0:
            res = res / out_peak

    return res

In [None]:
# Demo: 0.8 s impulse response (int(sr * room_size) samples) with moderate damping.
# The impulse response is random noise, so each run sounds slightly different.
reverb_room_size = 0.8  # Room size parameter (0-1)
reverb_damping = 0.6  # Damping parameter (0-1)
reverb_my_audio = apply_reverb_audio(my_audio, sr=my_track.rate, 
                                      room_size=reverb_room_size, damping=reverb_damping, normalize=True)
print(f"Reverb audio shape: {reverb_my_audio.shape}")
# Play the reverb audio
display(Audio(reverb_my_audio, rate=my_track.rate))
plot_audio_comparison_with_melspec(my_audio, reverb_my_audio, sample_rate=my_track.rate,
                                     title=f"Original vs Reverb Audio (Room Size: {reverb_room_size}, Damping: {reverb_damping})")

### One Function To Rule Them All

In [None]:
def apply_augmentations(audio, sr, augmentation_types, config=config):
    """
    Apply the requested augmentations to mono audio, with randomly drawn parameters.

    Augmentations always run in the fixed order
    pitch shift -> time stretch -> compression -> reverb, regardless of their
    order in `augmentation_types`. Each parameter is drawn with `random.uniform`
    from the matching *_RANGE attribute of `config`. Duplicate names are applied once.

    Args:
        audio (np.ndarray): Audio data as a mono (1D) numpy array.
        sr (int): Sample rate of the audio data.
        augmentation_types (list): Strings naming the augmentations to apply.
            Valid types: 'pitch_shift', 'time_stretch', 'compression', 'reverb', 'noise'.
            'noise' is accepted but NOT applied (add_noise_audio and
            NOISE_LEVEL_RANGE are disabled in this notebook). It is reported
            as skipped in the description instead of being silently ignored.
        config: Object providing the *_RANGE attributes. Defaults to the
            notebook's `config` instance (bound when this function is defined).
            Passing None raises ValueError.

    Returns:
        tuple: (augmented_audio, augmentation_description)
              augmented_audio is the processed, peak-normalized numpy array
              augmentation_description is a "; "-joined string of what was applied
              ("No augmentation" if nothing was applied)

    Raises:
        ValueError: If audio is not a 1D numpy array (mono)
        ValueError: If augmentation_types is not a list of strings
        ValueError: If config is None
        ValueError: If an invalid augmentation type is specified
    """
    if audio.ndim > 1:
        raise ValueError("Expected mono audio array (1D), but got multi-dimensional audio.")

    valid_augmentation_types = ['pitch_shift', 'time_stretch', 'compression', 'reverb', 'noise']

    if not isinstance(augmentation_types, list) or not all(isinstance(aug, str) for aug in augmentation_types):
        raise ValueError("augmentation_types must be a list of strings specifying augmentation types. Valid options are: " + ", ".join(valid_augmentation_types))

    if config is None:
        raise ValueError("Config object must be provided with augmentation parameters.")

    for aug_type in augmentation_types:
        if aug_type not in valid_augmentation_types:
            raise ValueError(f"Invalid augmentation type: '{aug_type}'. Valid options are: {', '.join(valid_augmentation_types)}")
    requested = set(augmentation_types)

    applied_augmentations = []
    result = audio.copy()

    if 'pitch_shift' in requested:
        n_steps = random.uniform(*config.PITCH_SHIFT_RANGE)
        result = pitch_shift_audio(result, sr=sr, n_steps=n_steps, normalize=True)
        applied_augmentations.append(f"Pitch shift: {n_steps:.2f} semitones")

    if 'time_stretch' in requested:
        stretch_factor = random.uniform(*config.TIME_STRETCH_RANGE)
        result = time_stretch_audio(result, stretch_factor=stretch_factor, normalize=True)
        applied_augmentations.append(f"Time stretch: {stretch_factor:.2f}x")

    if 'noise' in requested:
        # Noise augmentation is disabled (its helper and range are commented out);
        # make that visible in the description rather than silently doing nothing.
        applied_augmentations.append("Noise: skipped (not implemented)")

    if 'compression' in requested:
        threshold = random.uniform(*config.COMPRESSION_THRESHOLD_RANGE)
        ratio = random.uniform(*config.COMPRESSION_RATIO_RANGE)
        result = compress_dynamic_range_audio(result, threshold=threshold, ratio=ratio, normalize=True)
        applied_augmentations.append(f"Compression: {ratio:.1f}:1 @ {threshold:.2f}")

    if 'reverb' in requested:
        room_size = random.uniform(*config.REVERB_ROOM_SIZE_RANGE)
        damping = random.uniform(*config.REVERB_DAMPING_RANGE)
        result = apply_reverb_audio(result, sr=sr, room_size=room_size, damping=damping, normalize=True)
        applied_augmentations.append(f"Reverb: room={room_size:.2f}, damp={damping:.2f}")

    # Final peak normalization; skip it for silence to avoid 0/0 -> NaN
    peak = np.max(np.abs(result))
    if peak > 0:
        result = result / peak

    description = "; ".join(applied_augmentations) if applied_augmentations else "No augmentation"

    return result, description

In [None]:
# Apply all active augmentations with parameters drawn at random from `config`'s ranges.
# apply_augmentations always runs them in a fixed order (pitch, stretch, compression, reverb).
augmentation_types = ['pitch_shift', 'time_stretch', 'compression', 'reverb']


aug_my_audio, aug_description = apply_augmentations(my_audio, sr=my_track.rate, augmentation_types=augmentation_types, config=config)
print(f"Augmented audio shape: {aug_my_audio.shape}")
# Play the augmented audio
display(Audio(aug_my_audio, rate=my_track.rate))
# The description lists the randomly drawn parameters actually used
plot_audio_comparison_with_melspec(my_audio, aug_my_audio, sample_rate=my_track.rate,
                                     title=f"Original vs Augmented Audio\n({aug_description})")

## Data Augmentation Pipeline

### Audio Suitability Check

In [None]:
def is_suitable(audio, energy_threshold=0.005, silent_threshold=0.002, min_active_ratio=0.3, window_size=4410):
    """
    Check whether an audio segment is loud enough and active enough to be worth processing.

    Two criteria must both hold:
      1. The overall RMS energy of the segment is strictly greater than ``energy_threshold``.
      2. When the segment is split into consecutive, non-overlapping windows of
         ``window_size`` samples (a trailing partial window is ignored), at least
         ``min_active_ratio`` of the windows have an RMS strictly greater than
         ``silent_threshold``.
    If the segment is shorter than one window, criterion 2 falls back to comparing the
    overall RMS against ``silent_threshold``.

    Args:
        audio (np.ndarray): Audio samples along axis 0. Intended for mono (1D) audio; for a
            multi-channel array the RMS of each window is taken over all of its channels.
            The input is never modified.
        energy_threshold (float): Minimum overall RMS energy (exclusive).
        silent_threshold (float): RMS level above which a window counts as non-silent.
        min_active_ratio (float): Minimum fraction of non-silent windows (0.0 to 1.0, inclusive).
        window_size (int): Window length in samples (4410 samples = 0.1 s at 44.1 kHz).

    Returns:
        bool: True if the segment satisfies both criteria, False otherwise
        (an empty array is never suitable).

    Raises:
        ValueError: If ``window_size`` is not a positive integer.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be a positive integer, got {window_size}.")

    # Work in float64: squaring integer PCM samples would otherwise overflow, and the
    # RMS of long segments is more accurate. np.asarray does not copy float64 input.
    samples = np.asarray(audio, dtype=np.float64)

    # An empty segment has no energy; answer explicitly instead of letting np.mean()
    # warn and return NaN.
    if samples.size == 0:
        return False

    # Criterion 1: overall energy
    overall_rms = np.sqrt(np.mean(samples ** 2))
    if overall_rms <= energy_threshold:
        return False

    num_windows = len(samples) // window_size

    # Shorter than a single window: fall back to the overall RMS
    if num_windows == 0:
        return bool(overall_rms > silent_threshold)

    # Criterion 2: frame the signal into one row per window. The C-order reshape keeps
    # each row made of consecutive samples (and all their channels for multi-channel
    # input), so all window RMS values are computed in a single vectorised pass.
    frames = samples[: num_windows * window_size].reshape(num_windows, -1)
    window_rms = np.sqrt(np.mean(frames ** 2, axis=1))
    active_ratio = np.count_nonzero(window_rms > silent_threshold) / num_windows

    return bool(active_ratio >= min_active_ratio)

In [None]:
# Sanity-check the demo clip with the same thresholds the augmentation pipeline uses
is_suitable_result = is_suitable(
    my_audio,
    energy_threshold=0.005,
    silent_threshold=0.002,
    min_active_ratio=0.3,
    window_size=4410,
)
print(f"Is the audio suitable for processing? {'Yes' if is_suitable_result else 'No'}")

### Semi-Coherent Augmentation: Coherent Augmentation with Varying Parameters

In [None]:
def coherent_augmentation_varying_params(mus_train, idx, config=config, duration=10.0, show_audio=False, max_attempts=10):
    """
    Extract the stems of one MUSDB18 training track and augment each stem independently.

    A random ``duration``-second segment is chosen in which every stem passes
    ``is_suitable``. Up to ``max_attempts`` offsets are tried; if none passes, the last
    segment tried is used and a warning is printed.

    All stems come from the same segment of the same track. Each stem is peak-normalized
    and then processed with its own random subset of augmentations (pitch shift, time
    stretch, compression and reverb are each picked with 60% probability) and its own
    random parameters.

    Args:
        mus_train (musdb.DB): MUSDB18 training set object
        idx (int): Index of the track in ``mus_train.tracks`` (0 <= idx < number of tracks)
        config: Configuration object for augmentation parameters
        duration (float): Duration in seconds to process for each stem (at least 1 second
            and no longer than the track)
        show_audio (bool): If True, display audio players for original and augmented stems
        max_attempts (int): Maximum number of attempts to find suitable stem segments (>= 1)

    Returns:
        tuple: (original_stems, augmented_stems)
            original_stems: {
                'track_name': str,  # Name of the track
                'vocals': np.ndarray,
                'bass': np.ndarray,
                'drums': np.ndarray,
                'other': np.ndarray,
                'mixture': np.ndarray,  # Average (sum / 4) of the original stems
            }
            augmented_stems: {
                'descriptions': {'vocals': str, 'bass': str, 'drums': str, 'other': str},
                'track_name': str,  # Name of the track with '_aug' suffix
                'vocals': np.ndarray,
                'bass': np.ndarray,
                'drums': np.ndarray,
                'other': np.ndarray,
                'mixture': np.ndarray,  # Average (sum / 4) of the augmented stems
            }

    Raises:
        ValueError: If ``idx`` is out of range, ``duration`` or ``max_attempts`` is invalid,
            or the track is missing one of the four MUSDB18 stems.
    """
    # =============== #
    # VALIDATE INPUTS #
    # =============== #
    num_tracks = len(mus_train.tracks)
    # Valid indices are 0 .. num_tracks - 1 inclusive; negative indices are rejected too
    if not 0 <= idx < num_tracks:
        raise ValueError(f"Index {idx} is out of bounds for the training set with {num_tracks} tracks.")
    # At least one attempt is required, otherwise no segment is ever selected
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1.")

    # Extract the requested track as mono audio
    track = extract_track_by_index(mus_train, idx, to_mono=True)
    sr = track.rate
    print(f"Selected track: {track.name} (index {idx})")

    # =========================== #
    # VALIDATE DURATION PARAMETER #
    # =========================== #
    if duration is None:
        raise ValueError("Duration must be specified for coherent augmentation.")
    if duration <= 0:
        raise ValueError("Duration must be a positive number.")
    track_seconds = track.audio.shape[0] / sr
    if duration > track_seconds:
        raise ValueError(f"Duration {duration} seconds exceeds track length {track_seconds:.2f} seconds.")
    if duration < 1.0:
        raise ValueError("Duration must be at least 1 second for coherent augmentation.")

    segment_len = int(duration * sr)  # number of samples per processed segment

    # =========================== #
    # CHECKING STEMS AVAILABILITY #
    # =========================== #
    stem_names = ['vocals', 'bass', 'drums', 'other']
    missing_stems = [stem for stem in stem_names if stem not in track.sources]
    if missing_stems:
        raise ValueError(f"Track '{track.name}' is missing the following stems: {', '.join(missing_stems)}")

    def _peak_normalize(signal):
        """Scale ``signal`` to a peak magnitude of 1; an all-zero signal is returned unchanged."""
        peak = np.max(np.abs(signal))
        return signal / peak if peak > 0 else signal

    # =============================== #
    # FIND SUITABLE SEGMENT FOR STEMS #
    # =============================== #
    # Latest start time that still leaves room for the full duration
    max_offset = track_seconds - duration

    all_stems_suitable = False
    attempt_count = 0
    while not all_stems_suitable and attempt_count < max_attempts:
        attempt_count += 1
        print(f"Attempt {attempt_count}/{max_attempts} to find suitable stems segment...")

        offset = random.uniform(0, max_offset)
        start_sample = int(offset * sr)
        end_sample = start_sample + segment_len

        # Every stem must be suitable in this same segment
        stems_suitability = {
            stem: is_suitable(track.sources[stem].audio[start_sample:end_sample],
                              energy_threshold=0.005, silent_threshold=0.002,
                              min_active_ratio=0.3, window_size=4410)
            for stem in stem_names
        }
        all_stems_suitable = all(stems_suitability.values())

        print(f"Segment at offset {offset:.2f}s:")
        for stem, suitable in stems_suitability.items():
            print(f"  - {stem}: {'Suitable' if suitable else 'Not suitable'}")

        if all_stems_suitable:
            print(f"Found suitable segment after {attempt_count} attempts at offset {offset:.2f}s")
        elif attempt_count == max_attempts:
            print(f"Warning: Could not find segment with all suitable stems after {max_attempts} attempts.")
            print("Proceeding with the last segment checked.")

    # ============================== #
    # INITIALIZE RETURN DICTIONARIES #
    # ============================== #
    original_stems = {'track_name': track.name}
    augmented_stems = {'descriptions': {}, 'track_name': track.name + "_aug"}

    # ================================= #
    # APPLY AUGMENTATIONS FOR ALL STEMS #
    # ================================= #
    # Noise augmentation is currently disabled, so it is not part of the pool
    augmentation_pool = ('pitch_shift', 'time_stretch', 'compression', 'reverb')

    for stem in stem_names:
        original_audio = _peak_normalize(track.sources[stem].audio[start_sample:end_sample])

        # Each augmentation type is independently selected with 60% probability
        random_augmentations = [aug for aug in augmentation_pool if random.random() < 0.6]

        augmented_audio, description = apply_augmentations(
            audio=original_audio, sr=sr, augmentation_types=random_augmentations, config=config
        )

        # Time stretching changes the length: pad with silence or crop back to segment_len
        if len(augmented_audio) < segment_len:
            augmented_audio = np.pad(augmented_audio, (0, segment_len - len(augmented_audio)), mode='constant')
        else:
            augmented_audio = augmented_audio[:segment_len]
        augmented_audio = _peak_normalize(augmented_audio)

        original_stems[stem] = original_audio
        augmented_stems[stem] = augmented_audio
        augmented_stems['descriptions'][stem] = description

        if show_audio:
            print(f"Stem: {stem}")
            display(Audio(original_audio, rate=sr))
            print(f"Augmentation: {description}")
            display(Audio(augmented_audio, rate=sr))
            print()

    # ================================================================ #
    # COMBINE ALL ORIGINAL AND AUGMENTED STEMS INTO 2 DIFFERENT ARRAYS #
    # ================================================================ #
    # Divide by 4 (the number of stems) so the sum of peak-normalized stems stays in [-1, 1]
    original_combined = np.sum([original_stems[stem] for stem in stem_names], axis=0) / 4
    augmented_combined = np.sum([augmented_stems[stem] for stem in stem_names], axis=0) / 4

    original_stems['mixture'] = original_combined
    augmented_stems['mixture'] = augmented_combined

    # =================================== #
    # DISPLAY COMBINED AUDIO IF REQUESTED #
    # =================================== #
    if show_audio:
        print("\nOriginal mixture of all stems:")
        display(Audio(original_combined, rate=sr))

        print("\nAugmented mixture of all stems:")
        display(Audio(augmented_combined, rate=sr))

    print("Augmentation complete")
    return original_stems, augmented_stems

In [None]:
# Demo: augment every stem of track #1 with independently drawn parameters and listen to it
orig_stems, aug_stems = coherent_augmentation_varying_params(
    mus_train,
    idx=1,
    config=config,
    duration=10.0,
    show_audio=True,
)

In [None]:
print(f"Original Track name: {orig_stems['track_name']}")
print(f"Augmented Track name: {aug_stems['track_name']}\n")

# Every original entry except the name is an array (the mixture included)
print("Original stems:")
for stem, audio in orig_stems.items():
    if stem == 'track_name':
        continue
    print(f"{stem.capitalize()}: {audio.shape} samples")

# Only the per-stem arrays: skip the mixture and the metadata entries
print("\nAugmented stems:")
for stem, audio in aug_stems.items():
    if stem in ('mixture', 'descriptions', 'track_name'):
        continue
    print(f"{stem.capitalize()}: {audio.shape} samples")

In [None]:
# Inspect the structure of both result dictionaries and the per-stem augmentation log
print("orig_stems.keys():", orig_stems.keys())
print("aug_stems.keys():", aug_stems.keys())
print("aug_stems['descriptions']:", aug_stems['descriptions'])

In [None]:
# Plot original vs augmented for every stem and for the mixture
for stem in aug_stems:
    if stem in ('descriptions', 'track_name'):
        continue
    plot_audio_comparison_with_melspec(
        orig_stems[stem],
        aug_stems[stem],
        sample_rate=44100,
        title=f"Original vs Augmented {stem.capitalize()} Audio",
    )

### Coherent Augmentation

In [None]:
def coherent_augmentation_fixed_params(mus_train, idx, config, duration=10.0, show_audio=False, max_attempts=10):
    """
    Extract the stems of one MUSDB18 training track and apply the SAME augmentation chain to every stem.

    A random ``duration``-second segment is chosen in which every stem passes
    ``is_suitable``. Up to ``max_attempts`` offsets are tried; if none passes, the last
    segment tried is used and a warning is printed.

    One random subset of augmentations is drawn once (pitch shift, time stretch,
    compression and reverb are each picked with 60% probability), together with one set
    of random parameters. They are then applied identically to all four stems.

    Args:
        mus_train (musdb.DB): MUSDB18 training set object
        idx (int): Index of the track in ``mus_train.tracks`` (0 <= idx < number of tracks)
        config: Configuration object for augmentation parameters
        duration (float): Duration in seconds to process for each stem (at least 1 second
            and no longer than the track)
        show_audio (bool): If True, display audio players for original and augmented stems
        max_attempts (int): Maximum number of attempts to find suitable stem segments (>= 1)

    Returns:
        tuple: (original_stems, augmented_stems)
            original_stems: {
                'track_name': str,  # Name of the track
                'vocals': np.ndarray,  # peak-normalized
                'bass': np.ndarray,
                'drums': np.ndarray,
                'other': np.ndarray,
                'mixture': np.ndarray,  # Average (sum of stem / 4) of the original stems
            }
            augmented_stems: {
                'descriptions': {'vocals': str, 'bass': str, 'drums': str, 'other': str},
                'track_name': str,  # Name of the track with '_aug' suffix
                'vocals': np.ndarray,  # peak-normalized
                'bass': np.ndarray,
                'drums': np.ndarray,
                'other': np.ndarray,
                'mixture': np.ndarray,  # Average (sum of stem / 4) of the augmented stems
            }

    Raises:
        ValueError: If ``idx`` is out of range, ``duration`` or ``max_attempts`` is invalid,
            or the track is missing one of the four MUSDB18 stems.
    """
    # =============== #
    # VALIDATE INPUTS #
    # =============== #
    num_tracks = len(mus_train.tracks)
    # Valid indices are 0 .. num_tracks - 1 inclusive (the last track is a valid choice)
    if not 0 <= idx < num_tracks:
        raise ValueError(f"Index {idx} is out of bounds for the training set with {num_tracks} tracks.")
    # At least one attempt is required, otherwise no segment is ever selected
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1.")

    track = extract_track_by_index(mus_train, idx, to_mono=True)
    sr = track.rate
    print(f"Selected track: {track.name} (index {idx})")

    # =========================== #
    # VALIDATE DURATION PARAMETER #
    # =========================== #
    if duration is None:
        raise ValueError("Duration must be specified for coherent augmentation.")
    if duration <= 0:
        raise ValueError("Duration must be a positive number.")
    track_seconds = track.audio.shape[0] / sr
    if duration > track_seconds:
        raise ValueError(f"Duration {duration} seconds exceeds track length {track_seconds:.2f} seconds.")
    if duration < 1.0:
        raise ValueError("Duration must be at least 1 second for coherent augmentation.")

    segment_len = int(duration * sr)  # number of samples per processed segment

    # =========================== #
    # CHECKING STEMS AVAILABILITY #
    # =========================== #
    stem_names = ['vocals', 'bass', 'drums', 'other']
    missing_stems = [stem for stem in stem_names if stem not in track.sources]
    if missing_stems:
        raise ValueError(f"Track '{track.name}' is missing the following stems: {', '.join(missing_stems)}")

    def _peak_normalize(signal):
        """Scale ``signal`` to a peak magnitude of 1; an all-zero signal is returned unchanged."""
        peak = np.max(np.abs(signal))
        return signal / peak if peak > 0 else signal

    # =============================== #
    # FIND SUITABLE SEGMENT FOR STEMS #
    # =============================== #
    # Latest start time that still leaves room for the full duration
    max_offset = track_seconds - duration

    all_stems_suitable = False
    attempt_count = 0
    while not all_stems_suitable and attempt_count < max_attempts:
        attempt_count += 1
        print(f"Attempt {attempt_count}/{max_attempts} to find suitable stems segment...")

        offset = random.uniform(0, max_offset)
        start_sample = int(offset * sr)
        end_sample = start_sample + segment_len

        # Every stem must be suitable in this same segment
        stems_suitability = {
            stem: is_suitable(
                track.sources[stem].audio[start_sample:end_sample],
                energy_threshold=0.005,
                silent_threshold=0.002,
                min_active_ratio=0.3,
                window_size=4410,
            )
            for stem in stem_names
        }
        all_stems_suitable = all(stems_suitability.values())

        print(f"Segment at offset {offset:.2f}s:")
        for stem, suitable in stems_suitability.items():
            print(f"  - {stem}: {'Suitable' if suitable else 'Not suitable'}")

        if all_stems_suitable:
            print(f"Found suitable segment after {attempt_count} attempts at offset {offset:.2f}s")
        elif attempt_count == max_attempts:
            print(f"Warning: Could not find segment with all suitable stems after {max_attempts} attempts.")
            print("Proceeding with the last segment checked.")

    # ============================== #
    # INITIALIZE RETURN DICTIONARIES #
    # ============================== #
    original_stems = {'track_name': track.name}
    augmented_stems = {'descriptions': {}, 'track_name': track.name + "_aug"}

    # ================================================= #
    # DECIDE ON FIXED AUGMENTATION TYPES AND PARAMETERS #
    # ================================================= #
    # Noise augmentation is currently disabled, so it is not part of the pool.
    # Each augmentation type is independently selected with 60% probability.
    augmentation_pool = ('pitch_shift', 'time_stretch', 'compression', 'reverb')
    fixed_augmentations = [aug for aug in augmentation_pool if random.random() < 0.6]

    # Draw every random parameter once so all stems receive exactly the same processing
    fixed_params = {}
    if 'pitch_shift' in fixed_augmentations:
        fixed_params['pitch_shift'] = random.uniform(*config.PITCH_SHIFT_RANGE)
    if 'time_stretch' in fixed_augmentations:
        fixed_params['time_stretch'] = random.uniform(*config.TIME_STRETCH_RANGE)
    if 'compression' in fixed_augmentations:
        fixed_params['compression_threshold'] = random.uniform(*config.COMPRESSION_THRESHOLD_RANGE)
        fixed_params['compression_ratio'] = random.uniform(*config.COMPRESSION_RATIO_RANGE)
    if 'reverb' in fixed_augmentations:
        fixed_params['reverb_room_size'] = random.uniform(*config.REVERB_ROOM_SIZE_RANGE)
        fixed_params['reverb_damping'] = random.uniform(*config.REVERB_DAMPING_RANGE)

    # ================================= #
    # APPLY AUGMENTATIONS FOR ALL STEMS #
    # ================================= #
    for stem in stem_names:
        original_audio = track.sources[stem].audio[start_sample:end_sample]

        augmented_audio = original_audio.copy()
        applied_augmentations = []

        # Same order as apply_augmentations: pitch -> stretch -> compression -> reverb.
        # NOTE(review): apply_augmentations passes normalize=True to the compression and
        # reverb helpers, while these calls rely on the helpers' defaults. Confirm the
        # difference is intended.
        if 'pitch_shift' in fixed_augmentations:
            n_steps = fixed_params['pitch_shift']
            augmented_audio = pitch_shift_audio(augmented_audio, sr=sr, n_steps=n_steps)
            applied_augmentations.append(f"Pitch shift: {n_steps:.2f} semitones")

        if 'time_stretch' in fixed_augmentations:
            stretch_factor = fixed_params['time_stretch']
            augmented_audio = time_stretch_audio(augmented_audio, stretch_factor=stretch_factor)
            applied_augmentations.append(f"Time stretch: {stretch_factor:.2f}x")

        if 'compression' in fixed_augmentations:
            threshold = fixed_params['compression_threshold']
            ratio = fixed_params['compression_ratio']
            augmented_audio = compress_dynamic_range_audio(augmented_audio, threshold=threshold, ratio=ratio)
            applied_augmentations.append(f"Compression: {ratio:.1f}:1 @ {threshold:.2f}")

        if 'reverb' in fixed_augmentations:
            room_size = fixed_params['reverb_room_size']
            damping = fixed_params['reverb_damping']
            augmented_audio = apply_reverb_audio(augmented_audio, sr=sr, room_size=room_size, damping=damping)
            applied_augmentations.append(f"Reverb: room={room_size:.2f}, damp={damping:.2f}")

        description = "; ".join(applied_augmentations) if applied_augmentations else "No augmentation"

        # Time stretching changes the length: pad with silence or crop back to segment_len
        if len(augmented_audio) < segment_len:
            augmented_audio = np.pad(augmented_audio, (0, segment_len - len(augmented_audio)), mode='constant')
        elif len(augmented_audio) > segment_len:
            augmented_audio = augmented_audio[:segment_len]

        # Peak-normalize both versions to [-1, 1], guarding against silent stems
        original_stems[stem] = _peak_normalize(original_audio)
        augmented_stems[stem] = _peak_normalize(augmented_audio)
        augmented_stems['descriptions'][stem] = description

        if show_audio:
            print(f"Stem: {stem}")
            display(Audio(original_audio, rate=sr))
            print(f"Augmentation: {description}")
            display(Audio(augmented_audio, rate=sr))
            print()

    # ====================================================== #
    # COMBINE ALL ORIGINAL AND AUGMENTED STEMS INTO 2 ARRAYS #
    # ====================================================== #
    # Scale each stem by 1/4 (the number of stems) so the sum stays within [-1, 1]
    original_combined = np.sum([original_stems[stem] / 4 for stem in stem_names], axis=0)
    augmented_combined = np.sum([augmented_stems[stem] / 4 for stem in stem_names], axis=0)

    original_stems['mixture'] = original_combined
    augmented_stems['mixture'] = augmented_combined

    # =================================== #
    # DISPLAY COMBINED AUDIO IF REQUESTED #
    # =================================== #
    if show_audio:
        print("\nOriginal mixture of all stems:")
        display(Audio(original_combined, rate=sr))

        print("\nAugmented mixture of all stems:")
        display(Audio(augmented_combined, rate=sr))

    print("Augmentation complete")
    return original_stems, augmented_stems

In [None]:
# Demo: apply one shared augmentation chain to every stem of track #1
orig_stems_fixed, aug_stems_fixed = coherent_augmentation_fixed_params(
    mus_train, idx=1, config=config, duration=10.0, show_audio=True
)
print(f"Original Track name: {orig_stems_fixed['track_name']}")
print(f"Augmented Track name: {aug_stems_fixed['track_name']}\n")

# Every original entry except the name is an array (the mixture included)
print("Original stems:")
for stem, audio in orig_stems_fixed.items():
    if stem == 'track_name':
        continue
    print(f"{stem.capitalize()}: {audio.shape} samples")

# Only the per-stem arrays: skip the mixture and the metadata entries
print("Augmented stems:")
for stem, audio in aug_stems_fixed.items():
    if stem in ('mixture', 'descriptions', 'track_name'):
        continue
    print(f"{stem.capitalize()}: {audio.shape} samples")

In [None]:
# Plot original vs augmented for every stem and for the mixture
for stem in aug_stems_fixed:
    if stem in ('descriptions', 'track_name'):
        continue
    plot_audio_comparison_with_melspec(
        orig_stems_fixed[stem],
        aug_stems_fixed[stem],
        sample_rate=44100,
        title=f"Original vs Augmented {stem.capitalize()} Audio",
    )

### Incoherent Augmentation

In [None]:
def incoherent_augmentation(mus_train, config=config, duration=10.0, show_audio=False, max_attempts=10):
    """
    Build an incoherent mixture from stems taken from FOUR DIFFERENT MUSDB18 training tracks.

    Track selection, for each stem type (vocals, bass, drums, other):
      - A track not yet used by a previous stem is picked at random.
      - Up to ``max_attempts`` random ``duration``-second segments of that stem are
        checked with ``is_suitable``.
      - A track that lacks the stem, is shorter than ``duration`` or yields no suitable
        segment is skipped, and another unused track is tried. At most
        ``max_attempts * 3`` tracks are tried per stem.

    Processing of each selected stem:
      - It is peak-normalized.
      - It is augmented with its own random subset of augmentations (pitch shift, time
        stretch, compression and reverb are each picked with 60% probability).
      - It is padded or cropped back to ``duration`` seconds and peak-normalized again.

    Args:
        mus_train (musdb.DB): MUSDB18 training set object
        config: Configuration object for augmentation parameters
        duration (float): Duration in seconds to process for each stem (at least 1 second)
        show_audio (bool): If True, display audio players for original and augmented stems
        max_attempts (int): Maximum number of segment offsets tried per track

    Returns:
        tuple: (original_stems, augmented_stems)
            original_stems: {
                'track_name': str,  # "INCOHERENT_TRACK_NAME_" + random integer in 1..9999
                'vocals': np.ndarray,  # peak-normalized source segment
                'bass': np.ndarray,
                'drums': np.ndarray,
                'other': np.ndarray,
                'mixture': np.ndarray,  # Average (sum of stem / 4) of the original stems
            }
            augmented_stems: {
                'descriptions': {'vocals': str, 'bass': str, 'drums': str, 'other': str},
                'track_name': str,  # "AUG_" + original_stems['track_name']
                'vocals': np.ndarray,  # peak-normalized
                'bass': np.ndarray,
                'drums': np.ndarray,
                'other': np.ndarray,
                'mixture': np.ndarray,  # Average (sum of stem / 4) of the augmented stems
            }

    Raises:
        ValueError: If ``duration`` is invalid, or if no suitable segment can be found for a
            stem (including when every track has already been used or rejected).
    """
    # Standard stem names in MUSDB18
    stem_names = ['vocals', 'bass', 'drums', 'other']

    # =========================== #
    # VALIDATE DURATION PARAMETER #
    # =========================== #
    if duration is None:
        raise ValueError("Duration must be specified for incoherent augmentation.")
    if duration <= 0:
        raise ValueError("Duration must be a positive number.")
    if duration < 1.0:
        raise ValueError("Duration must be at least 1 second for incoherent augmentation.")

    def _peak_normalize(signal):
        """Scale ``signal`` to a peak magnitude of 1; an all-zero signal is returned unchanged."""
        peak = np.max(np.abs(signal))
        return signal / peak if peak > 0 else signal

    # Initialize return dictionaries
    original_stems = {}
    augmented_stems = {}
    augmented_stems['descriptions'] = {}
    original_stems['track_name'] = "INCOHERENT_TRACK_NAME_" + str(random.randint(1, 9999))
    augmented_stems['track_name'] = "AUG_" + original_stems['track_name']

    num_tracks = len(mus_train.tracks)
    # Tracks already providing a stem: excluded so the four stems come from different tracks
    used_track_indices = set()
    # Noise augmentation is currently disabled, so it is not part of the pool
    augmentation_pool = ('pitch_shift', 'time_stretch', 'compression', 'reverb')

    for stem in stem_names:
        stem_found = False
        attempts = 0
        # Tracks that cannot provide THIS stem (missing, too short, or no suitable segment)
        rejected_track_indices = set()

        while not stem_found and attempts < max_attempts * 3:
            candidates = [
                i for i in range(num_tracks)
                if i not in used_track_indices and i not in rejected_track_indices
            ]
            if not candidates:
                break  # every track has been used or rejected
            attempts += 1

            track_idx = random.choice(candidates)
            track = extract_track_by_index(mus_train, track_idx, to_mono=True)
            sr = track.rate

            # Skip (rather than abort on) tracks that cannot provide this stem
            if stem not in track.sources:
                print(f"Skipping track '{track.name}': missing the '{stem}' stem.")
                rejected_track_indices.add(track_idx)
                continue
            track_seconds = track.audio.shape[0] / sr
            if track_seconds < duration:
                print(f"Skipping track '{track.name}': too short for the specified duration of {duration} seconds.")
                rejected_track_indices.add(track_idx)
                continue

            # Latest start time that still leaves room for the full duration
            max_offset = track_seconds - duration

            # Try several random offsets in this track
            for _ in range(max_attempts):
                offset = random.uniform(0, max_offset)
                start_sample = int(offset * sr)
                end_sample = start_sample + int(duration * sr)
                stem_audio = track.sources[stem].audio[start_sample:end_sample]

                if is_suitable(stem_audio,
                               energy_threshold=0.005,
                               silent_threshold=0.002,
                               min_active_ratio=0.3,
                               window_size=4410):
                    print(f"Found suitable {stem} from track '{track.name}' at offset {offset:.2f}s")
                    # Normalized like the coherent variants so both mixtures are level-comparable
                    original_stems[stem] = _peak_normalize(stem_audio)
                    used_track_indices.add(track_idx)
                    stem_found = True
                    break

            if not stem_found:
                rejected_track_indices.add(track_idx)

        if not stem_found:
            raise ValueError(f"Could not find a suitable {stem} segment after {attempts} attempts across different tracks.")

        # Each augmentation type is independently selected with 60% probability
        random_augmentations = [aug for aug in augmentation_pool if random.random() < 0.6]

        augmented_audio, description = apply_augmentations(
            audio=original_stems[stem],
            sr=sr,
            augmentation_types=random_augmentations,
            config=config
        )

        # Time stretching changes the length: pad with silence or crop back to the duration
        segment_len = int(duration * sr)
        if len(augmented_audio) < segment_len:
            augmented_audio = np.pad(augmented_audio, (0, segment_len - len(augmented_audio)), mode='constant')
        elif len(augmented_audio) > segment_len:
            augmented_audio = augmented_audio[:segment_len]
        # Cropping can lower the peak; re-normalize like the coherent variants
        augmented_audio = _peak_normalize(augmented_audio)

        augmented_stems[stem] = augmented_audio
        augmented_stems['descriptions'][stem] = description

        if show_audio:
            print(f"Stem: {stem} from track '{track.name}'")
            display(Audio(original_stems[stem], rate=sr))
            print(f"Augmentation: {description}")
            display(Audio(augmented_audio, rate=sr))
            print()

    # ====================================================== #
    # COMBINE ALL ORIGINAL AND AUGMENTED STEMS INTO 2 ARRAYS #
    # ====================================================== #
    # Scale each stem by 1/4 (the number of stems) so the sum stays within [-1, 1]
    original_combined = np.sum([original_stems[stem] / 4 for stem in stem_names], axis=0)
    augmented_combined = np.sum([augmented_stems[stem] / 4 for stem in stem_names], axis=0)

    original_stems['mixture'] = original_combined
    augmented_stems['mixture'] = augmented_combined

    # =================================== #
    # DISPLAY COMBINED AUDIO IF REQUESTED #
    # =================================== #
    if show_audio:
        print("\nOriginal mixture of stems from different tracks:")
        display(Audio(original_combined, rate=sr))

        print("\nAugmented mixture of stems from different tracks:")
        display(Audio(augmented_combined, rate=sr))

    print("Incoherent augmentation complete")
    return original_stems, augmented_stems

In [None]:
# Demo: build one incoherent mix from stems of different tracks and listen to it
incoherent_orig_stems, incoherent_aug_stems = incoherent_augmentation(
    mus_train, config=config, duration=10.0, show_audio=True
)

print(f"Original Track name: {incoherent_orig_stems['track_name']}")
print(f"Augmented Track name: {incoherent_aug_stems['track_name']}\n")

# Only the per-stem arrays: skip the mixture and the metadata entries
print("Original stems from incoherent augmentation:")
for stem, audio in incoherent_orig_stems.items():
    if stem in ('mixture', 'track_name'):
        continue
    print(f"{stem.capitalize()}: {audio.shape} samples")
print("Augmented stems from incoherent augmentation:")
for stem, audio in incoherent_aug_stems.items():
    if stem in ('mixture', 'descriptions', 'track_name'):
        continue
    print(f"{stem.capitalize()}: {audio.shape} samples")


In [None]:
# Plot original vs augmented for every stem and for the mixture of the incoherent demo
for stem in incoherent_aug_stems:
    if stem in ('descriptions', 'track_name'):
        continue
    plot_audio_comparison_with_melspec(
        incoherent_orig_stems[stem],
        incoherent_aug_stems[stem],
        sample_rate=44100,
        title=f"Original vs Augmented {stem.capitalize()} Audio (Incoherent)",
    )

In [None]:
# Raw dump of the original-stems dictionary (track name, per-stem arrays and mixture)
print(incoherent_orig_stems.items())

### Save augmented stems

In [None]:
def save_augmented_stems(augmented_stems, output_dir, sample_rate=44100):
    """
    Save every audio entry of an augmented-stems dict as a WAV file.

    Each audio entry is written to
    ``<output_dir>/<stem>/<track_name>_<stem>.wav``. The standard sub-folders
    (mixture, vocals, bass, drums, other) are always created. A sub-folder is
    also created for any other audio key present, so an extra stem no longer
    fails on a missing directory.

    Args:
        augmented_stems (dict): Maps stem name -> audio array, plus the
            metadata key 'track_name' (used in the file names) and optionally
            'descriptions'. Neither metadata key is written to disk.
        output_dir (str): Directory where the stems will be saved.
        sample_rate (int): Sample rate passed to ``sf.write``. Defaults to
            44100, the value previously hard-coded here.

    Raises:
        KeyError: If ``augmented_stems`` has no 'track_name' entry.
    """
    metadata_keys = {'descriptions', 'track_name'}

    # Keep the fixed folder layout even when one of these stems is absent.
    os.makedirs(output_dir, exist_ok=True)
    for stem_dir in ('mixture', 'vocals', 'bass', 'drums', 'other'):
        os.makedirs(os.path.join(output_dir, stem_dir), exist_ok=True)

    aug_track_name = augmented_stems['track_name']

    for stem, audio in augmented_stems.items():
        if stem in metadata_keys:
            continue
        stem_dir = os.path.join(output_dir, stem)
        # No-op for the standard stems; creates folders for any extra key.
        os.makedirs(stem_dir, exist_ok=True)
        file_path = os.path.join(stem_dir, f"{aug_track_name}_{stem}.wav")
        sf.write(file_path, audio, sample_rate)
        print(f"Saved {stem} stem to {file_path}")

    print(f"All augmented stems for '{aug_track_name}' saved to {output_dir}")

In [None]:
# Demo: save the two example augmentations produced earlier in the notebook
# (aug_stems_fixed is presumably the coherent example from
# coherent_augmentation_fixed_params; confirm). Each kind goes into its own
# sub-folder so the incoherent example is no longer mixed into a directory
# named "Coherent_Augmentation" alongside the coherent one.
OUTPUT_DIR = "/Users/agathe/Desktop/Coherent_Augmentation"   # Output directory for augmented files
save_augmented_stems(incoherent_aug_stems, os.path.join(OUTPUT_DIR, "incoherent"))
print("\n")
save_augmented_stems(aug_stems_fixed, os.path.join(OUTPUT_DIR, "coherent"))

## Data Generation

### Coherent Augmentation

In [None]:
# Coherent dataset generation over every 40th training track (a
# demonstration-sized subset): 10-second excerpts, no playback. A failure on
# one track is reported and the loop continues with the next one.
coherent_indices = range(0, len(mus_train.tracks), 40)
for idx in coherent_indices:
    print(f"\nProcessing track at index {idx}...")
    try:
        orig_stems, aug_stems = coherent_augmentation_fixed_params(
            mus_train,
            idx=idx,
            config=config,
            duration=10.0,
            show_audio=False,
        )
        save_augmented_stems(aug_stems, config.OUTPUT_DIR_COHERENT_MIX)
    except Exception as e:
        print(f"Error processing track at index {idx}: {e}")

### Incoherent Augmentation

In [None]:
# Incoherent dataset generation. `incoherent_augmentation` takes no track
# index (it presumably chooses the source tracks of its stems internally;
# confirm in its definition). The counter therefore only controls how many
# mixes are produced: the same count as iterating every 40th training track.
# The old "Processing track at index {idx}" message implied a track that was
# never used.
# NOTE(review): if two mixes ever get the same 'track_name', the later save
# overwrites the earlier files. Confirm names are unique.
n_incoherent_mixes = len(range(0, len(mus_train.tracks), 40))
for mix_idx in range(n_incoherent_mixes):
    print(f"\nGenerating incoherent mix {mix_idx + 1}/{n_incoherent_mixes}...")
    try:
        orig_stems, aug_stems = incoherent_augmentation(mus_train, config=config, duration=10.0, show_audio=False)
        save_augmented_stems(aug_stems, config.OUTPUT_DIR_INCOHERENT_MIX)
    except Exception as e:
        # Top-level generation loop: report and keep producing the other mixes.
        print(f"Error generating incoherent mix {mix_idx + 1}: {e}")