# Hybrid Demucs V4 Training Pipeline

This notebook trains a TensorFlow implementation of the Hybrid Demucs model for music source separation. The model separates mixed audio into 13 individual instrument stems.

## Features
- **Multi-GPU training** using `MirroredStrategy`
- **Mixed precision** (BFloat16) for faster training on H100 GPUs
- **Checkpoint resumption**: automatically resumes from the best saved checkpoint (lowest loss)
- **Data augmentation** via random chunk selection from longer audio files

## Dataset
Training uses the Slakh2100 dataset, pre-chunked into per-track directories that each hold a mix file and its separated instrument stems.
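
The data generator in this notebook matches files by name: a file whose name contains `mix_chunk` is treated as the mix, and a file whose name starts with `<instrument>_chunk_` (case-insensitive) is treated as that instrument's stem. As a minimal sketch under that assumption, the helper below reports which stems a track directory contains, which may help spot missing stems before training. The name `index_track_dir`, the `.wav` extension, and the file names in the usage lines are hypothetical and not part of the notebook.

```python
import os
import tempfile

# Illustrative subset of the 13 instrument names used in the notebook.
INSTRUMENTS = ["Guitar", "Drums", "Piano", "Bass"]

def index_track_dir(track_dir, instruments=INSTRUMENTS):
    """Return (mix_file_or_None, {instrument: stem_file_or_None})."""
    files = sorted(os.listdir(track_dir))
    mix = next((f for f in files if "mix_chunk" in f.lower()), None)
    stems = {
        name: next((f for f in files
                    if f.lower().startswith(name.lower() + "_chunk_")), None)
        for name in instruments
    }
    return mix, stems

# Usage on a throwaway directory with made-up file names.
with tempfile.TemporaryDirectory() as d:
    for fname in ["mix_chunk_0.wav", "Guitar_chunk_0.wav", "Bass_chunk_0.wav"]:
        open(os.path.join(d, fname), "w").close()
    mix_file, stem_files = index_track_dir(d)

print(mix_file)    # mix_chunk_0.wav
print(stem_files)  # Drums and Piano map to None here
```

A stem that maps to `None` corresponds to the case where the generator substitutes an all-zero target for that instrument.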

In [None]:
# ==============================================================================
# Imports and Configuration
# ==============================================================================

import os
import random
import glob
import numpy as np
import librosa
import tensorflow as tf  # TensorFlow 2.19+
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras import mixed_precision

from demucs_v4_model import demucs_v4_fixed, custom_loss

# ==============================================================================
# Audio Constants
# ==============================================================================

SR            = 44_100      # Sample rate (44.1 kHz)
CHUNK_SECS    = 10          # Duration of each training chunk
CHUNK_SAMPLES = SR * CHUNK_SECS  # 10 seconds * 44,100 samples/sec = 441,000
PADDED_LEN    = 441_000     # Model input length

# ==============================================================================
# Audio Utility Functions
# ==============================================================================

def load_mono(fp, sr=SR):
    """Load audio file as mono float32 in range [-1, 1]."""
    wav, _ = librosa.load(fp, sr=sr, mono=True)
    wav = wav.astype(np.float32)
    # Sanitize: replace NaN/Inf and clip to valid range
    wav = np.nan_to_num(wav, nan=0.0, posinf=0.0, neginf=0.0)
    wav = np.clip(wav, -1.0, 1.0)
    return wav

def pad_or_trim(x, tgt_len=PADDED_LEN):
    """Pad with zeros or trim audio to exact target length."""
    if len(x) < tgt_len:
        return np.pad(x, (0, tgt_len - len(x)))
    return x[:tgt_len]


## Data Generator

The generator yields batches of (mix, stems_dict) pairs indefinitely for training. Each batch contains:
- **mix**: The combined audio of all instruments `(batch, 441000, 1)`
- **stems_dict**: Dictionary mapping instrument names to separated audio `{instrument_i: (batch, 441000, 1)}`
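
The batch layout and the shared normalization can be checked on synthetic data. The sketch below is illustrative only: it uses two random stand-in stems rather than real audio, scales the mix and its stems by the same peak value (as the generator does), and confirms that the scaled stems still sum to the scaled mix. The key names follow the `instrument_i` pattern described above; everything else is an assumption made for the example.

```python
import numpy as np

SR, CHUNK_SECS = 44_100, 10
N = SR * CHUNK_SECS                      # 441,000 samples per chunk
rng = np.random.default_rng(0)

# Two synthetic "stems" and their mix (stand-ins for real audio).
stems = {
    "instrument_1": 0.3 * rng.standard_normal(N).astype(np.float32),
    "instrument_2": 0.2 * rng.standard_normal(N).astype(np.float32),
}
mix = sum(stems.values())

# Scale mix and stems by the SAME peak so the stems still sum to the mix.
peak = float(np.max(np.abs(mix)))
mix_n = mix / peak
stems_n = {k: v / peak for k, v in stems.items()}

# Add batch and channel axes: (batch, samples, 1).
batch_size = 2
mix_batch = np.stack([mix_n] * batch_size)[..., None]
tgt_batch = {k: np.stack([v] * batch_size)[..., None] for k, v in stems_n.items()}

print(mix_batch.shape)                   # (2, 441000, 1)
print(tgt_batch["instrument_1"].shape)   # (2, 441000, 1)
print(np.allclose(sum(stems_n.values()), mix_n, atol=1e-5))
```

Normalizing each stem by its own peak instead would break this additivity, which is presumably why the generator reuses the mix peak for every stem.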

In [None]:
# Target instrument stems (13 total)
INSTRUMENT_NAMES = [
    "Guitar", "Drums", "Piano", "Bass", "Strings (continued)",
    "Organ", "Synth Lead", "Synth Pad", "Chromatic Percussion",
    "Brass", "Pipe", "Reed", "Strings"
]
# Map instrument names to model output keys
MODEL_KEYS = {n: f"instrument_{i+1}" for i, n in enumerate(INSTRUMENT_NAMES)}


def data_generator(root, batch_size=8):
    """
    Infinite generator yielding (mix, targets_dict) batches for training.
    
    Args:
        root: Path to directory containing track subdirectories
        batch_size: Number of samples per batch
    
    Yields:
        Tuple of (mix_batch, targets_dict) where:
        - mix_batch: (batch, PADDED_LEN, 1) mixed audio
        - targets_dict: {instrument_key: (batch, PADDED_LEN, 1)} stem audio
    """
    root = os.path.expanduser(root)
    track_dirs = [d for d in glob.glob(os.path.join(root, '*')) if os.path.isdir(d)]
    n_tracks   = len(track_dirs)
    chunk      = CHUNK_SAMPLES

    while True:            # epoch loop
        random.shuffle(track_dirs)

        for i in range(0, n_tracks, batch_size):
            dirs = track_dirs[i:i + batch_size]
            if len(dirs) < batch_size:
                continue

            mixes   = []
            targets = {k: [] for k in MODEL_KEYS.values()}

            for d in dirs:
                # Load the mix audio file
                mix_files  = [f for f in os.listdir(d) if 'mix_chunk' in f.lower()]
                if not mix_files:
                    continue
                mix_full = load_mono(os.path.join(d, mix_files[0]))

                # Random starting offset (if long enough). Default to 0 so that
                # `start` is always defined when the stems are sliced below.
                start = 0
                if len(mix_full) > chunk:
                    start = np.random.randint(0, len(mix_full) - chunk + 1)
                    mix_clip = mix_full[start:start + chunk]
                else:
                    mix_clip = pad_or_trim(mix_full, chunk)

                # Peak-normalize for training stability
                peak = float(np.nanmax(np.abs(mix_clip)))
                if not np.isfinite(peak) or peak < 1e-4:
                    peak = 1.0  # Near-silent or invalid clip: leave it unscaled
                mix_clip = mix_clip / peak

                # Load and normalize each instrument stem
                stem_dict = {}
                for name in INSTRUMENT_NAMES:
                    fmatch = next(
                        (f for f in os.listdir(d)
                         if f.lower().startswith(name.lower()+'_chunk_')),
                        None
                    )
                    if fmatch:
                        full = load_mono(os.path.join(d, fmatch))
                        full = np.nan_to_num(full, nan=0.0, posinf=0.0, neginf=0.0)
                        if len(full) > chunk:
                            stem = full[start:start + chunk]
                        else:
                            stem = pad_or_trim(full, chunk)
                        stem = stem / peak  # Use same normalization as mix
                    else:
                        stem = np.zeros(chunk, dtype=np.float32)
                    stem_dict[name] = stem

                # Ensure exact length for model input
                mix_pad = pad_or_trim(mix_clip, PADDED_LEN)
                mixes.append(mix_pad)

                for name in INSTRUMENT_NAMES:
                    targets[MODEL_KEYS[name]].append(
                        pad_or_trim(stem_dict[name], PADDED_LEN)[..., None]
                    )

            if not mixes:
                continue  # Skip empty batch

            # Convert to numpy arrays with channel dimension
            mix_batch = np.array(mixes, dtype=np.float32)[..., None]
            tgt_batch = {k: np.array(v, dtype=np.float32) for k, v in targets.items()}

            yield mix_batch, tgt_batch



## Checkpoint Management

Utilities for finding the best checkpoint and exporting the final model.
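
The selection rule depends on the loss value embedded in each checkpoint filename. The sketch below isolates that step as a pure function over filenames, using the same regular expression as `_best_weights_path` and the two checkpoint names that appear in this notebook's training output. The helper name `pick_best` is hypothetical.

```python
import re

# Filename convention written by the training callback:
# ckpt_e{epoch:02d}_loss{loss:.6f}.weights.h5
LOSS_RE = re.compile(r"_(?:val)?[lL]oss([0-9.]+)\.weights\.h5$")

def pick_best(filenames):
    """Return (loss, filename) with the lowest parsed loss, or None."""
    scored = []
    for name in filenames:
        m = LOSS_RE.search(name)
        if m:
            scored.append((float(m.group(1)), name))
    return min(scored) if scored else None

names = [
    "ckpt_e01_loss2.263522.weights.h5",
    "ckpt_e23_loss1.173109.weights.h5",
    "last.weights.h5",   # no loss in the name, so it is ignored
]
best = pick_best(names)
print(best)   # (1.173109, 'ckpt_e23_loss1.173109.weights.h5')
```

Because the callback monitors `loss`, the parsed value is the training loss, not a validation loss.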

In [None]:
import os
import re
import glob
import tensorflow as tf

# Configuration - adjust these paths for your project
checkpoint_dir = "demucs_v4_fixed_ckpt"       # Directory for weight checkpoints
output_model   = "demucs_v4_fixed_model.keras"  # Final exported model path
INPUT_SHAPE    = (PADDED_LEN, 1)              # Must match training input shape

def _best_weights_path(ckpt_dir: str) -> str:
    """Find the best checkpoint file by parsing loss values from filenames."""
    paths = glob.glob(os.path.join(ckpt_dir, "*.weights.h5"))
    if not paths:
        raise FileNotFoundError(f"No .weights.h5 files in {ckpt_dir}")

    # Parse loss value from checkpoint filenames (e.g., "ckpt_e01_loss0.123456.weights.h5")
    loss_re = re.compile(r"_(?:val)?[lL]oss([0-9.]+)\.weights\.h5$")

    scored = []
    for p in paths:
        m = loss_re.search(p)
        if m:
            try:
                scored.append((float(m.group(1)), p))
            except ValueError:
                pass

    if scored:
        scored.sort(key=lambda t: t[0])  # Lowest loss = best
        best_loss, best_path = scored[0]
        print(f"Selected best by loss: {os.path.basename(best_path)} (loss={best_loss:.6f})")
        return best_path

    # Fallback: use newest file by modification time
    best_path = max(paths, key=os.path.getmtime)
    print(f"Selected newest weights: {os.path.basename(best_path)}")
    return best_path


def export_model_from_best():
    """Load best weights and export complete model with architecture."""
    best_path = _best_weights_path(checkpoint_dir)

    # Rebuild model graph (must match training architecture)
    model = demucs_v4_fixed(INPUT_SHAPE)
    model.load_weights(best_path)
    print(f"Loaded weights from {best_path}")

    # Save full model (architecture + weights)
    model.save(output_model)
    print(f"Saved full model to {output_model}")


## Training Function

The main training loop with:
- Multi-GPU distribution strategy
- Loss-scaled optimizer for mixed precision stability
- Best and last checkpoint callbacks
- Early stopping when the loss stops improving (patience of 3 epochs), plus termination on NaN loss
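
To make the loss-scaling item concrete, here is a toy model of dynamic loss scaling in plain Python. It illustrates the general technique (halve the scale and skip the step when gradients overflow, double it after a run of finite steps) and is not Keras's actual implementation. The class name `DynamicLossScale` and its default values are assumptions chosen for the example. Loss scaling is mainly needed for float16; bfloat16 has the same exponent range as float32, so it is usually less critical there.

```python
import math

class DynamicLossScale:
    """Toy dynamic loss scale: halve on overflow, double after a run of good steps."""

    def __init__(self, initial_scale=2.0 ** 15, growth_steps=2000):
        self.scale = initial_scale
        self.growth_steps = growth_steps
        self.good_steps = 0

    def update(self, grads):
        """Return True if the step should be applied, False if it is skipped."""
        if all(math.isfinite(g) for g in grads):
            self.good_steps += 1
            if self.good_steps >= self.growth_steps:
                self.scale *= 2.0
                self.good_steps = 0
            return True
        # Overflow: shrink the scale and skip this optimizer step.
        self.scale = max(self.scale / 2.0, 1.0)
        self.good_steps = 0
        return False

scaler = DynamicLossScale(initial_scale=8.0, growth_steps=2)
applied = [
    scaler.update([0.1, -0.2]),          # finite: applied
    scaler.update([float("inf"), 0.0]),  # overflow: skipped, scale 8 -> 4
    scaler.update([0.3, 0.1]),           # finite: applied
    scaler.update([0.2, 0.0]),           # second finite step: scale 4 -> 8
]
print(applied, scaler.scale)   # [True, False, True, True] 8.0
```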

In [None]:
def train_demucs_v4_fixed(data_dir,
                          checkpoint_dir='demucs_v4_fixed_ckpt',
                          batch_size=2,
                          steps_per_epoch=1275,
                          epochs=25):
    """
    Train the Hybrid Demucs V4 model with multi-GPU support.
    
    Args:
        data_dir: Path to training data directory
        checkpoint_dir: Directory to save weight checkpoints
        batch_size: Global batch size; MirroredStrategy splits each batch across replicas
        steps_per_epoch: Training steps before epoch ends
        epochs: Total training epochs
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Multi-GPU training with MirroredStrategy
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        # Optimizer with gradient clipping and loss scaling for mixed precision
        base_opt = tf.keras.optimizers.Adam(learning_rate=1e-5, clipnorm=0.5)
        opt = mixed_precision.LossScaleOptimizer(base_opt, dynamic=True)

        # Build and compile model
        model = demucs_v4_fixed((PADDED_LEN, 1))
        model.compile(
            optimizer=opt,
            loss=custom_loss,
            jit_compile=False,
            run_eagerly=False
        )

        # Resume from checkpoint if available
        try:
            best = _best_weights_path(checkpoint_dir)
            print(f"Resuming from: {os.path.basename(best)}")
            model.load_weights(best)
        except Exception as err:
            # Report why resumption failed instead of silently discarding the error
            print(f"Starting training from scratch ({err})")

    # Callback: Save best checkpoint by loss
    ckpt_best = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, "ckpt_e{epoch:02d}_loss{loss:.6f}.weights.h5"),
        monitor="loss",
        mode="min",
        save_weights_only=True,
        save_best_only=True,
        save_freq='epoch',
    )

    # Callback: Always save latest checkpoint
    ckpt_last = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, 'last.weights.h5'),
        save_weights_only=True,
        save_best_only=False,
        save_freq='epoch'
    )

    # Callback: Early stopping and NaN detection
    es_cb = EarlyStopping(monitor='loss', patience=3, restore_best_weights=True)
    nan_cb = tf.keras.callbacks.TerminateOnNaN()

    # Create tf.data.Dataset with prefetching
    option = tf.data.Options()
    option.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    ds = tf.data.Dataset.from_generator(
        lambda: data_generator(data_dir, batch_size),
        output_signature=(
            tf.TensorSpec(shape=(None, PADDED_LEN, 1),  dtype=tf.float32),
            {k: tf.TensorSpec(shape=(None, PADDED_LEN, 1), dtype=tf.float32)
             for k in MODEL_KEYS.values()}
        )
    ).prefetch(tf.data.AUTOTUNE).with_options(option)

    # Train the model
    try:
        model.fit(
            ds,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            callbacks=[ckpt_best, ckpt_last, es_cb, nan_cb],
            verbose=1,
        )
    finally:
        # Export best model when training completes (or fails)
        try:
            export_model_from_best()
        except FileNotFoundError:
            print("No checkpoints found; exporting current model instead.")
            model.save(output_model)


## Run Training

Execute training on the Slakh dataset. Adjust `batch_size`, `steps_per_epoch`, and `epochs` as needed.
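
When adjusting these values, it can help to work out how much data one epoch consumes. The sketch below is arithmetic only. It assumes the batch produced by the generator is the global batch that `MirroredStrategy` divides across replicas, and it takes the replica count of 8 from the GPUs listed in the training output. The function names and the track count of 1000 are made up for illustration.

```python
def epoch_budget(batch_size, steps_per_epoch, num_replicas):
    """Summarize how many examples one epoch consumes."""
    return {
        "examples_per_epoch": batch_size * steps_per_epoch,
        "per_replica_batch": batch_size / num_replicas,
    }

def steps_for_one_pass(n_tracks, batch_size):
    """Full batches in one pass over n_tracks (a short final batch is dropped)."""
    return n_tracks // batch_size

budget = epoch_budget(batch_size=8, steps_per_epoch=1275, num_replicas=8)
print(budget)  # {'examples_per_epoch': 10200, 'per_replica_batch': 1.0}

# 1000 is a made-up track count, used only to show the floor division.
print(steps_for_one_pass(n_tracks=1000, batch_size=8))  # 125
```

Setting `steps_per_epoch` to the value returned by `steps_for_one_pass` for the real track count would make one Keras epoch correspond to one shuffled pass over the data.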

In [None]:
# Training configuration
DATA_ROOT = os.path.expanduser('~/madari3/gcs-bucket/Slakh_Dataset_Chunked/train_chunked')

train_demucs_v4_fixed(
    data_dir=DATA_ROOT,
    batch_size=8,
    steps_per_epoch=1275,
    epochs=25
)
    


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3', '/job:localhost/replica:0/task:0/device:GPU:4', '/job:localhost/replica:0/task:0/device:GPU:5', '/job:localhost/replica:0/task:0/device:GPU:6', '/job:localhost/replica:0/task:0/device:GPU:7')


I0000 00:00:1759295710.532050   40866 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78761 MB memory:  -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:61:00.0, compute capability: 9.0
I0000 00:00:1759295710.533660   40866 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 78761 MB memory:  -> device: 1, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:62:00.0, compute capability: 9.0
I0000 00:00:1759295710.535184   40866 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 78761 MB memory:  -> device: 2, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:63:00.0, compute capability: 9.0
I0000 00:00:1759295710.536683   40866 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 78761 MB memory:  -> device: 3, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:64:00.0, compute capability: 9.0

→ Selected best by loss: ckpt_e01_loss2.263522.weights.h5 (loss=2.263522)
Resuming from best: ckpt_e01_loss2.263522.weights.h5




INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

2025-10-01 05:15:58.743859: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: CANCELLED: GetNextFromShard was cancelled
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]

Epoch 1/25
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 8, group_size = 8, implementation = CommunicationImplementation.NCCL, num_packs = 1

I0000 00:00:1759296013.812500   41871 cuda_dnn.cc:529] Loaded cuDNN version 90800
I0000 00:00:1759296021.090434   41805 service.cc:152] XLA service 0x7b7e68811e30 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1759296021.090461   41805 service.cc:160]   StreamExecutor device (0): NVIDIA H100 80GB HBM3, Compute Capability 9.0

1275/1275 ━━━━━━━━━━━━━━━━━━━━ 6726s 5s/step - instrument_10_loss: 0.0923 - instrument_11_loss: 0.0651 - instrument_12_loss: 0.1392 - instrument_13_loss: 0.0725 - instrument_1_loss: 0.2797 - instrument_2_loss: 0.1242 - instrument_3_loss: 0.2658 - instrument_4_loss: 0.3364 - instrument_5_loss: 0.2206 - instrument_6_loss: 0.1244 - instrument_7_loss: 0.0938 - instrument_8_loss: 0.1555 - instrument_9_loss: 0.0765 - loss: 2.0459
Epoch 2/25
1275/1275 ━━━━━━━━━━━━━━━━━━━━ 5415s 4s/step - instrument_10_loss: 0.0621 - instrument_11_loss: 0.0456 - instrument_12_loss: 0.0948 - instrument_13_loss: 0.0564 - instrument_1_loss: 0.2541 - instrument_2_loss: 0.0957 - instrument_3_loss: 0.2406 - instrument_4_loss: 0.2952 - instrument_5_loss: 0.1967 - instrument_6_loss: 0.1055 - instrument_7_loss: 0.0694 - instrument_8_loss: 0.1302 - instrument_9_loss: 0.0446 

2025-10-02 20:49:12.242960: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: CANCELLED: GetNextFromShard was cancelled
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]
	 [[RemoteCall]] [type.googleapis.com/tensorflow.DerivedStatus='']


→ Selected best by loss: ckpt_e23_loss1.173109.weights.h5 (loss=1.173109)
✅ Loaded weights from demucs_v4_fixed_ckpt/ckpt_e23_loss1.173109.weights.h5
💾 Saved full model → demucs_v4_fixed_model.keras
