# Hybrid Demucs V4 Inference

This notebook separates a mixed audio file into 13 individual instrument stems using a trained Hybrid Demucs model.

## Features
- **Overlap-add processing** for seamless separation of long audio files
- **Batch prediction** for efficient GPU utilization
- **Automatic preprocessing**: multi-channel input is downmixed to mono and resampled to 44.1 kHz if needed

In [None]:
# ==============================================================================
# GPU Configuration
# ==============================================================================

import os
import tensorflow as tf

# Enable memory growth so TensorFlow allocates GPU memory on demand instead of
# reserving all of it upfront. Memory growth can only be configured before the
# GPUs are initialized, so re-running this cell in a live kernel (e.g. after the
# model has been loaded) raises RuntimeError; report it instead of aborting.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f'Enabled memory growth on {len(gpus)} GPU(s)')
    except RuntimeError as err:
        print(f'Could not enable memory growth (GPUs already initialized): {err}')
else:
    print('No GPUs available - using CPU')


Enabled memory growth on 8 GPU(s)


In [None]:
# ==============================================================================
# Imports and Constants
# ==============================================================================

import os, math
import numpy as np
import librosa, soundfile as sf
import tensorflow as tf

from demucs_v4_model import (
    ExpandDims, ReduceMean, LocalSelfAttention, STFT, InverseSTFT, custom_loss
)

# Audio parameters
SR = 44_100            # sample rate (Hz) used for processing and for the written stems
PADDED_LEN = 10 * SR   # fallback chunk length in samples (10 s at 44.1 kHz)

# The 13 stem names, in the order of the model's `instrument_<n>` outputs
# (instrument_1 -> 'Guitar', ..., instrument_13 -> 'Strings').
INSTRUMENT_NAMES = [
    'Guitar',
    'Drums',
    'Piano',
    'Bass',
    'Strings (continued)',
    'Organ',
    'Synth Lead',
    'Synth Pad',
    'Chromatic Percussion',
    'Brass',
    'Pipe',
    'Reed',
    'Strings',
]

# name -> model output key, and the inverse lookup used when naming saved stems
MODEL_KEYS = {
    name: f'instrument_{idx}'
    for idx, name in enumerate(INSTRUMENT_NAMES, start=1)
}
KEY_TO_NAME = dict(zip(MODEL_KEYS.values(), MODEL_KEYS.keys()))


def model_chunk_len(m):
    """Return the chunk length (in samples) the model expects as input.

    Reads the time dimension (axis 1) of the model's first input. Falls back to
    ``PADDED_LEN`` when that dimension is unknown (``None``), zero, or absent
    because the input shape has fewer than two dimensions.
    """
    shape = m.inputs[0].shape
    try:
        length = shape[1]
    except (IndexError, TypeError, ValueError):
        # Rank < 2, unknown rank, or no shape at all: nothing to read.
        return PADDED_LEN
    return int(length) if length else PADDED_LEN


## Load Model

Load the trained model, registering the custom layers and loss it was saved with. Compilation is skipped because only inference is needed.

In [None]:
# Deserialize the trained model. The saved file references project-specific
# layers and a loss function, so Keras must be told how to rebuild them.
# Compilation is skipped: this notebook only runs inference.
model = tf.keras.models.load_model(
    'demucs_v4_fixed_model.keras',
    compile=False,
    custom_objects=dict(
        ExpandDims=ExpandDims,
        ReduceMean=ReduceMean,
        LocalSelfAttention=LocalSelfAttention,
        STFT=STFT,
        InverseSTFT=InverseSTFT,
        custom_loss=custom_loss,
    ),
)

n_output_stems = len(model.output_names)
print(f"Model loaded with {n_output_stems} output stems")


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3', '/job:localhost/replica:0/task:0/device:GPU:4', '/job:localhost/replica:0/task:0/device:GPU:5', '/job:localhost/replica:0/task:0/device:GPU:6', '/job:localhost/replica:0/task:0/device:GPU:7')
Using MirroredStrategy with 8 GPUs


ListWrapper(['conv1d_155', 'conv1d_173', 'conv1d_175', 'conv1d_177', 'conv1d_179', 'conv1d_157', 'conv1d_159', 'conv1d_161', 'conv1d_163', 'conv1d_165', 'conv1d_167', 'conv1d_169', 'conv1d_171'])

## Separation Function

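The main separation function. It splits the audio into 50%-overlapping chunks of the model's input length and peak-normalizes each chunk. It predicts the chunks in batches, then Hann-windows and overlap-adds the results, dividing by the summed window weights so chunk boundaries are seamless.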

In [None]:
def _load_mono_audio(audio_path, sr):
    """Read ``audio_path`` and return it as a float32 mono signal at rate ``sr``."""
    wav, file_sr = sf.read(audio_path, always_2d=False)
    if wav.ndim == 2:
        # Downmix *before* resampling: soundfile returns (frames, channels), and
        # librosa.resample works along the last axis, which would be the channels.
        wav = wav.mean(axis=1)
    if file_sr != sr and wav.size:
        wav = librosa.resample(wav, orig_sr=file_sr, target_sr=sr)
    return np.asarray(wav, dtype=np.float32)


def _outputs_as_list(out):
    """Normalize a ``model.predict`` result into ``(stem_names, arrays)``.

    Handles dict outputs (named via KEY_TO_NAME), list/tuple outputs (named via
    INSTRUMENT_NAMES by position), and a bare array from a single-output model.
    """
    if isinstance(out, dict):
        keys = list(out.keys())
        return [KEY_TO_NAME.get(k, str(k)) for k in keys], [out[k] for k in keys]
    # A single-output model returns one array rather than a list of arrays.
    arrays = [out] if hasattr(out, 'shape') else list(out)
    # NOTE(review): positional naming assumes the model's output order matches
    # INSTRUMENT_NAMES. The recorded output_names (conv1d_155, conv1d_173, ...)
    # hint that outputs may be ordered by sorted key (instrument_1,
    # instrument_10, ...) -- confirm before trusting names in this branch.
    names = [
        INSTRUMENT_NAMES[i] if i < len(INSTRUMENT_NAMES) else f"stem_{i+1}"
        for i in range(len(arrays))
    ]
    return names, arrays


def separate_long_audio(model, audio_path, output_dir, sr=SR, batch_size=8):
    """
    Separate a full-length audio file into instrument stems.

    The file is downmixed to mono, resampled to ``sr`` if needed, and cut into
    chunks of the model's input length with 50% overlap. Each chunk is
    peak-normalized, predicted in batches, scaled back by its peak,
    Hann-windowed and overlap-added; the sum is then divided by the summed
    window weights. One 16-bit PCM WAV per model output is written.

    Args:
        model: Trained Keras model. Each output is assumed to be shaped
            (batch, chunk, 1) -- TODO confirm for other architectures.
        audio_path: Path to input audio file (``~`` is expanded).
        output_dir: Directory to save separated stems (created if missing).
        sr: Sample rate to process and write at (default: 44100).
        batch_size: Number of chunks per ``model.predict`` call; values below 1
            are treated as 1.

    Raises:
        ValueError: if the audio file contains no samples.
    """
    audio_path = os.path.expanduser(audio_path)
    output_dir = os.path.expanduser(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    batch_size = max(1, int(batch_size))

    wav = _load_mono_audio(audio_path, sr)
    n = len(wav)
    if n == 0:
        raise ValueError(f'{audio_path!r} contains no audio samples')

    # Processing parameters
    chunk = model_chunk_len(model)
    hop = max(1, chunk // 2)  # 50% overlap; never 0, which would loop forever
    win = np.hanning(chunk).astype(np.float32)

    # Stem names and accumulators are created from the first prediction, so no
    # separate probe call is needed to discover the model's output structure.
    names, acc = None, None
    wsum = np.zeros(n + chunk, np.float32)
    batch, metas = [], []  # pending model inputs and their (start, peak)

    def flush():
        """Predict the pending batch and overlap-add it into the accumulators."""
        nonlocal names, acc
        if not batch:
            return
        out = model.predict(np.stack(batch, axis=0), verbose=0)
        out_names, outputs = _outputs_as_list(out)
        if names is None:
            names = out_names
            acc = [np.zeros(n + chunk, np.float32) for _ in outputs]
        elif len(outputs) != len(acc):
            raise RuntimeError(
                f'model returned {len(outputs)} outputs, expected {len(acc)}')
        # Outputs are matched by position, not by key, so batches whose dict
        # keys differ in name (but not in order) still line up.
        for b, (start, peak) in enumerate(metas):
            for stem_acc, stem_out in zip(acc, outputs):
                y = np.asarray(stem_out[b, :, 0], dtype=np.float32) * peak
                stem_acc[start:start + chunk] += y * win
            wsum[start:start + chunk] += win
        batch.clear()
        metas.clear()

    # Process audio in overlapping chunks
    for pos in range(0, n, hop):
        x = wav[pos:pos + chunk]  # non-empty because pos < n
        peak = float(np.max(np.abs(x)))
        if peak < 1e-7:  # (near-)silent chunk: don't amplify noise by dividing by ~0
            peak = 1.0
        if len(x) < chunk:
            x = np.pad(x, (0, chunk - len(x)))
        batch.append((x / peak).astype(np.float32)[:, None])
        metas.append((pos, peak))
        if len(batch) >= batch_size:
            flush()
    flush()

    # Normalize by the summed window weights and save stems
    eps = 1e-8
    norm = np.maximum(wsum[:n], eps)
    for name, stem_acc in zip(names, acc):
        y = (stem_acc[:n] / norm).astype(np.float32)
        sf.write(os.path.join(output_dir, f'{name}.wav'), y, sr, subtype='PCM_16')

    print(f"Saved {len(names)} stems to {output_dir}")


## Run Separation

Separate an audio file into stems, writing one WAV file per stem to the output directory. Adjust the paths (`~` is expanded) and the batch size as needed.

In [None]:
# Configuration -- '~' in either path is expanded inside separate_long_audio.
AUDIO_PATH = '~/path/to/input.wav'   # mixed track to separate
OUTPUT_DIR = '~/path/to/output/'     # receives one WAV file per stem
BATCH_SIZE = 8                        # chunks per model.predict call; raise if GPU memory allows

# Run separation
separate_long_audio(
    model=model,
    audio_path=AUDIO_PATH,
    output_dir=OUTPUT_DIR,
    batch_size=BATCH_SIZE,
)


2025-09-30 04:29:48.579940: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]
	 [[RemoteCall]]


KeyError: 'conv1d_155'