# Hybrid Demucs V4 Model Evaluation

This notebook evaluates the trained Hybrid Demucs model on the test set and includes utilities for separating full-length audio files.

## Contents
1. **Setup** - GPU configuration and imports
2. **Data Generator** - Test data loading
3. **Model Loading** - Load trained model with custom layers
4. **Evaluation** - Compute per-instrument loss on test set
5. **Audio Separation** - Separate a full-length audio file into stems

In [None]:
# ==============================================================================
# GPU Setup and Imports
# ==============================================================================

import os, random, glob

# Remove any TF_GPU_ALLOCATOR setting inherited from the environment. This is
# done before the TensorFlow import below; memory growth itself is enabled
# further down via tf.config.
os.environ.pop("TF_GPU_ALLOCATOR", None)

import tensorflow as tf

# Enable memory growth so TensorFlow allocates GPU memory on demand instead of
# reserving all of it upfront.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"Memory growth enabled on {len(gpus)} GPU(s)")
    except RuntimeError as err:
        # Memory growth must be configured before the GPUs are initialized.
        # Re-running this cell in a live kernel would otherwise abort here.
        print(f"Memory growth left unchanged (GPUs already initialized): {err}")

import numpy as np
import librosa
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras import mixed_precision

# Import custom model components
from demucs_v4_model import (
    demucs_v4_fixed,
    ExpandDims,
    ReduceMean,
    LocalSelfAttention,
    STFT,
    InverseSTFT,
    custom_loss,
)

# ==============================================================================
# Audio Constants and Utilities
# ==============================================================================

SR = 44_100                        # Sample rate in Hz (44.1 kHz)
CHUNK_SECS = 10                    # Length of one chunk, in seconds
CHUNK_SAMPLES = SR * CHUNK_SECS    # Samples per chunk: 10 s * 44_100 = 441_000
PADDED_LEN = CHUNK_SAMPLES         # Model input length, same as one chunk


def load_mono(fp, sr=SR):
    """Read the audio file `fp` via librosa as mono at rate `sr`; return float32 samples."""
    samples = librosa.load(fp, sr=sr, mono=True)[0]  # [1] is the sample rate, unused
    return samples.astype(np.float32)


def pad_or_trim(x, tgt_len=PADDED_LEN):
    """Return `x` with exactly `tgt_len` samples.

    Shorter inputs are zero-padded on the right; longer inputs are cut off
    after `tgt_len` samples.
    """
    missing = tgt_len - len(x)
    if missing > 0:
        return np.pad(x, (0, missing))
    return x[:tgt_len]

2025-10-29 06:44:11.502260: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-29 06:44:11.516511: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761720251.533903  102393 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761720251.539497  102393 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761720251.552808  102393 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Memory growth on – no per-process hard cap.


## Data Generator

Test data generator that yields (mix, stems) batches for evaluation.

In [None]:
# The 13 instrument stems the model separates, in model-output order.
INSTRUMENT_NAMES = [
    "Guitar",
    "Drums",
    "Piano",
    "Bass",
    "Strings (continued)",
    "Organ",
    "Synth Lead",
    "Synth Pad",
    "Chromatic Percussion",
    "Brass",
    "Pipe",
    "Reed",
    "Strings",
]
# Instrument name -> model output key ("instrument_1" ... "instrument_13").
MODEL_KEYS = {
    name: f"instrument_{position}"
    for position, name in enumerate(INSTRUMENT_NAMES, start=1)
}
print(MODEL_KEYS)


def _find_stem_file(files, name):
    """Return the first entry of `files` holding the stem for `name`, else None.

    A file matches when its lowercased name contains the lowercased instrument
    name and '_chunk_'. Files that also contain a longer instrument name which
    itself includes `name` are skipped, so "Strings" never picks up a
    "Strings (continued)" file.
    """
    needle = name.lower()
    longer_names = [
        other.lower() for other in INSTRUMENT_NAMES
        if other.lower() != needle and needle in other.lower()
    ]
    for fname in files:
        low = fname.lower()
        if needle not in low or '_chunk_' not in low:
            continue
        if any(longer in low for longer in longer_names):
            continue
        return fname
    return None


def data_generator(root, batch_size=8):
    """
    Infinite generator yielding (mix, targets_dict) batches for evaluation.

    Each track directory contributes one CHUNK_SAMPLES-long excerpt. The mix is
    peak-normalized and every stem is cropped at the same offset and divided by
    the same peak, so mix and stems stay time- and level-aligned. Instruments
    without a stem file get an all-zero target.

    Relies on the module-level MODEL_KEYS mapping instrument name -> output key.

    Args:
        root: Path to directory containing track subdirectories
        batch_size: Number of samples per batch

    Yields:
        (mix_batch, tgt_batch): float32 array of shape (B, PADDED_LEN, 1) and a
        dict mapping each model output key to an array of the same shape.

    Raises:
        FileNotFoundError: `root` contains no track subdirectories.
        RuntimeError: a full pass over the tracks produced no batch (no
            directory holds a 'mix_chunk' file).
    """
    root = os.path.expanduser(root)
    track_dirs = [d for d in glob.glob(os.path.join(root, '*')) if os.path.isdir(d)]
    if not track_dirs:
        # Without this check the loop below would spin forever without yielding.
        raise FileNotFoundError(f"No track directories found under: {root}")
    n_tracks = len(track_dirs)
    chunk = CHUNK_SAMPLES

    while True:
        random.shuffle(track_dirs)
        produced_batch = False

        for i in range(0, n_tracks, batch_size):
            dirs = track_dirs[i:i + batch_size]

            mixes   = []
            targets = {k: [] for k in MODEL_KEYS.values()}

            for d in dirs:
                # List the directory once; sorted so file selection is deterministic.
                files = sorted(os.listdir(d))

                mix_file = next((f for f in files if 'mix_chunk' in f.lower()), None)
                if mix_file is None:
                    continue
                mix_full = load_mono(os.path.join(d, mix_file))

                # Random starting offset when the mix is longer than one chunk.
                # `start` is always defined so the stems below reuse this
                # track's offset and never a stale one from another track.
                start = 0
                if len(mix_full) > chunk:
                    start = np.random.randint(0, len(mix_full) - chunk + 1)
                mix_clip = pad_or_trim(mix_full[start:start + chunk], chunk)

                # Peak-normalize for consistent scaling (epsilon guards silence).
                peak = float(np.max(np.abs(mix_clip))) + 1e-7
                mixes.append(pad_or_trim(mix_clip / peak, PADDED_LEN))

                for name in INSTRUMENT_NAMES:
                    fmatch = _find_stem_file(files, name)
                    if fmatch:
                        full = load_mono(os.path.join(d, fmatch))
                        # Same window and same scale as the mix.
                        stem = pad_or_trim(full[start:start + chunk], chunk) / peak
                    else:
                        stem = np.zeros(chunk, dtype=np.float32)
                    targets[MODEL_KEYS[name]].append(
                        pad_or_trim(stem, PADDED_LEN)[..., None]
                    )

            if not mixes:
                continue  # Skip empty batch

            # Convert to numpy arrays with channel dimension
            mix_batch = np.array(mixes, dtype=np.float32)[..., None]
            tgt_batch = {k: np.array(v, dtype=np.float32) for k, v in targets.items()}

            produced_batch = True
            yield mix_batch, tgt_batch

        if not produced_batch:
            # Every directory lacked a mix file; fail instead of looping forever.
            raise RuntimeError(f"No 'mix_chunk' files found in any track under: {root}")

{'Guitar': 'instrument_1', 'Drums': 'instrument_2', 'Piano': 'instrument_3', 'Bass': 'instrument_4', 'Strings (continued)': 'instrument_5', 'Organ': 'instrument_6', 'Synth Lead': 'instrument_7', 'Synth Pad': 'instrument_8', 'Chromatic Percussion': 'instrument_9', 'Brass': 'instrument_10', 'Pipe': 'instrument_11', 'Reed': 'instrument_12', 'Strings': 'instrument_13'}


## Load Trained Model

Load the saved model with custom layer definitions.

In [None]:
# MSE loss function for model compilation.
# NOTE(review): this definition replaces the `custom_loss` imported from
# demucs_v4_model at the top of the notebook — confirm the two are meant to match.
def custom_loss(y_true, y_pred):
    """Mean Squared Error loss for waveform reconstruction.

    Both tensors are cast to float32 before the difference is squared and
    averaged. The cast is a no-op for float32 inputs. The notebook enables the
    mixed_float16 policy, so if the model emits float16 this keeps the squares
    and the mean out of half precision (presumably where overflow could occur —
    TODO confirm against the model's output dtype).

    Args:
        y_true: Reference waveform tensor.
        y_pred: Predicted waveform tensor, broadcast-compatible with `y_true`.

    Returns:
        Scalar float32 tensor: mean of the squared differences.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    return tf.reduce_mean(tf.square(y_true - y_pred))


# Load the saved model from "demucs_v4_fixed_model.keras". The custom layers
# (and the loss name) are passed through `custom_objects` so Keras can resolve
# them during deserialization. compile=False skips restoring the saved
# compile state; the model is compiled explicitly in the evaluation section.
reloaded = tf.keras.models.load_model(
    "demucs_v4_fixed_model.keras",
    custom_objects={
        "ExpandDims": ExpandDims,
        "ReduceMean": ReduceMean,
        "LocalSelfAttention": LocalSelfAttention,
        "STFT": STFT,
        "InverseSTFT": InverseSTFT,
        "custom_loss": custom_loss,
    },
    compile=False
)

# Set the global mixed-precision policy to mixed_float16.
# NOTE(review): this runs AFTER load_model. A global policy presumably only
# applies to layers created afterwards, so it may not change `reloaded` —
# confirm whether it should be set before loading.
# NOTE(review): `mixed_precision` is already imported in the setup cell.
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")




I0000 00:00:1761719416.448287   78469 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78761 MB memory:  -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:61:00.0, compute capability: 9.0
I0000 00:00:1761719416.451082   78469 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 78761 MB memory:  -> device: 1, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:62:00.0, compute capability: 9.0
I0000 00:00:1761719416.452652   78469 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 78761 MB memory:  -> device: 2, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:63:00.0, compute capability: 9.0
I0000 00:00:1761719416.454238   78469 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 78761 MB memory:  -> device: 3, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:64:00.0, compute capability: 9.0
I0000 00:00:1761719416.455772   78469 gpu_device.cc:2019] Create

## Model Evaluation

Evaluate the model on the test dataset and report per-instrument loss.

In [None]:
# Evaluation settings
test_dir   = '~/madari3/gcs-bucket/Slakh_Dataset_Chunked/test_chunked'
BATCH_SIZE = 8
TEST_STEPS = 3545  # Number of evaluation batches
NUM_INST   = 13

# The model was loaded with compile=False, so attach the loss here before
# evaluate() is called. XLA compilation and eager execution are both disabled.
reloaded.compile(
    loss=custom_loss,
    optimizer='adam',
    run_eagerly=False,
    jit_compile=False,
)






In [None]:
# Create test dataset with prefetching. The generator is infinite, so the
# number of batches consumed is bounded by `steps` in evaluate() below.
test_ds = tf.data.Dataset.from_generator(
    lambda: data_generator(test_dir, BATCH_SIZE),
    output_signature=(
        tf.TensorSpec(shape=(None, PADDED_LEN, 1), dtype=tf.float32),
        {k: tf.TensorSpec(shape=(None, PADDED_LEN, 1), dtype=tf.float32)
         for k in MODEL_KEYS.values()}
    )
).prefetch(tf.data.AUTOTUNE)

# Run evaluation. return_dict=True makes `results` a {metric_name: value} dict,
# so each loss is looked up by name instead of by position in a list. The
# recorded progress bar lists the metrics as instrument_10, instrument_11, ...,
# instrument_1, ..., so a positional mapping is not safe to assume.
from tensorflow.keras.callbacks import ProgbarLogger
results = reloaded.evaluate(
    test_ds,
    steps=TEST_STEPS,
    callbacks=[ProgbarLogger()],
    verbose=1,
    return_dict=True,
)

# Print results in instrument order
print(f"\nTotal loss: {results['loss']:.4f}")
print("\nPer-instrument losses:")
for name in INSTRUMENT_NAMES:
    metric_name = f"{MODEL_KEYS[name]}_loss"
    if metric_name in results:
        print(f"  {name}: {results[metric_name]:.4f}")

made it
here


I0000 00:00:1753765925.681847    2716 cuda_dnn.cc:529] Loaded cuDNN version 90800
I0000 00:00:1753765930.255780    2716 service.cc:152] XLA service 0x767c7c00bf40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1753765930.255821    2716 service.cc:160]   StreamExecutor device (0): NVIDIA H100 PCIe, Compute Capability 9.0
I0000 00:00:1753765931.013276    2716 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m 465/3545[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m1:41:27[0m 2s/step - instrument_10_loss: nan - instrument_11_loss: nan - instrument_12_loss: nan - instrument_13_loss: nan - instrument_1_loss: nan - instrument_2_loss: nan - instrument_3_loss: nan - instrument_4_loss: nan - instrument_5_loss: nan - instrument_6_loss: nan - instrument_7_loss: nan - instrument_8_loss: nan - instrument_9_loss: nan - loss: nan

KeyboardInterrupt: 

## Audio Separation Utility

Separate a full-length audio file into individual instrument stems using overlap-add processing for seamless output.

In [None]:
import os, glob, math, random, sys
import numpy as np
import soundfile as sf
import tensorflow as tf
from pathlib import Path

# ==============================================================================
# Separation Constants
# ==============================================================================

SR = 44_100           # Sample rate in Hz
TARGET_LEN = 441_000  # Model output length (10 seconds at 44.1kHz)

INSTRUMENT_NAMES = [
    "Guitar",
    "Drums",
    "Piano",
    "Bass",
    "Strings (continued)",
    "Organ",
    "Synth Lead",
    "Synth Pad",
    "Chromatic Percussion",
    "Brass",
    "Pipe",
    "Reed",
    "Strings",
]
# Model output key -> instrument name.
# NOTE(review): this is the inverse of the MODEL_KEYS mapping defined in the
# data-generator cell (name -> key) and overwrites it; re-running the
# evaluation cells after this one will see the inverted mapping.
MODEL_KEYS = {
    f"instrument_{position}": label
    for position, label in enumerate(INSTRUMENT_NAMES, start=1)
}

# ==============================================================================
# Utility Functions
# ==============================================================================

def pad_or_trim(x: np.ndarray, length: int) -> np.ndarray:
    """Return `x` with exactly `length` samples.

    Longer inputs keep their first `length` samples. Shorter inputs are
    zero-padded on both sides, with the extra sample going on the right when
    the padding is odd.

    NOTE(review): this shadows the earlier `pad_or_trim`, which pads on the
    right only and has a default target length.
    """
    deficit = length - len(x)
    if deficit <= 0:
        return x[:length]
    lead, odd = divmod(deficit, 2)
    return np.pad(x, (lead, lead + odd))


def chunk_audio(wave: np.ndarray, chunk_size: int):
    """Split `wave` into consecutive `chunk_size`-sample pieces.

    Pieces do not overlap. A final piece that comes up short is zero-padded on
    the right to `chunk_size`. An empty `wave` yields nothing.
    """
    for offset in range(0, len(wave), chunk_size):
        piece = wave[offset:offset + chunk_size]
        shortfall = chunk_size - len(piece)
        if shortfall > 0:
            piece = np.pad(piece, (0, shortfall))
        yield piece


def model_chunk_len(model):
    """Return the number of input samples the model expects, as an int.

    Reads index 1 of `model.input_shape`. When that entry is None, the first
    truthy `target_len` attribute found on the model's layers is used, and
    441024 when no layer provides one.
    """
    time_dim = model.input_shape[1]
    if time_dim is not None:
        return int(time_dim)

    for layer in model.layers:
        candidate = getattr(layer, "target_len", None)
        if candidate:
            return int(candidate)
    return 441024



def load_audio(path, sr=44100):
    """Load an audio file as mono float32 at sample rate `sr`.

    Args:
        path: Path to the audio file (`~` is expanded).
        sr: Target sample rate; the audio is resampled when the file differs.

    Returns:
        1-D float32 numpy array of samples.

    Raises:
        FileNotFoundError: the resolved path does not exist.
    """
    p = Path(path).expanduser().resolve()
    if not p.exists():
        raise FileNotFoundError(f"Audio file not found: {p}")
    wav, file_sr = sf.read(str(p), always_2d=False)
    # Downmix BEFORE resampling: soundfile returns multi-channel audio as
    # (frames, channels), while librosa.resample works along the last axis by
    # default, so resampling first would run across channels instead of time.
    if wav.ndim == 2:
        wav = wav.mean(axis=1)  # Convert stereo to mono
    if file_sr != sr:
        import librosa
        wav = librosa.resample(wav, orig_sr=file_sr, target_sr=sr)
    return wav.astype(np.float32)


def separate_long_audio(model, audio_path, output_dir, sr=44100):
    """
    Separate a full-length audio file into instrument stems.

    The audio is processed in windows that overlap by 50%. Each prediction is
    weighted by a window and summed into a per-instrument buffer; the buffers
    are then divided by the summed window weights, so the result is a weighted
    average of the overlapping predictions. One float WAV per instrument is
    written to `output_dir`, named after the instrument.

    Args:
        model: Trained Keras model
        audio_path: Path to input audio file
        output_dir: Directory to save separated stems
        sr: Sample rate (default: 44100)

    Raises:
        ValueError: the audio file has no samples, or the model's input length
            is shorter than TARGET_LEN.
        TypeError: `model.predict` returns neither a list nor a dict.
    """
    # Load and prepare audio
    audio_path = Path(audio_path).expanduser().resolve()
    output_dir = Path(output_dir).expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    wav, file_sr = sf.read(str(audio_path), always_2d=False)
    # Downmix BEFORE resampling: soundfile returns multi-channel audio as
    # (frames, channels), while librosa.resample works along the last axis by
    # default, so resampling first would run across channels instead of time.
    if wav.ndim == 2:
        wav = wav.mean(axis=1)
    if file_sr != sr:
        import librosa
        wav = librosa.resample(wav, orig_sr=file_sr, target_sr=sr)
    wav = wav.astype(np.float32)

    n_samples = len(wav)
    if n_samples == 0:
        raise ValueError(f"Audio file contains no samples: {audio_path}")
    global_peak = float(np.max(np.abs(wav))) + 1e-7

    # Processing parameters
    chunk_len = model_chunk_len(model)
    if chunk_len < TARGET_LEN:
        raise ValueError(
            f"Model input length {chunk_len} is shorter than TARGET_LEN {TARGET_LEN}"
        )
    hop = TARGET_LEN // 2  # 50% overlap
    window = np.sqrt(np.hanning(TARGET_LEN).astype(np.float32))
    offset = (chunk_len - TARGET_LEN) // 2

    # Output keys in model order, with their human-readable names.
    key_to_name = {
        f"instrument_{i}": name for i, name in enumerate(INSTRUMENT_NAMES, start=1)
    }
    keys = list(key_to_name)

    # Overlap-add buffers: one accumulator per instrument plus ONE shared
    # weight buffer, since every instrument receives the same window weights.
    acc = {k: np.zeros(n_samples + chunk_len, np.float32) for k in keys}
    wsum = np.zeros(n_samples + chunk_len, np.float32)

    def predict_chunk(x):
        """Run model prediction on a single chunk and return stems dict."""
        x = x.astype(np.float32)
        x_norm = x / global_peak  # Scale by the file's peak before prediction
        if len(x_norm) < chunk_len:
            x_norm = np.pad(x_norm, (0, chunk_len - len(x_norm)))

        outs = model.predict(x_norm[np.newaxis, :, np.newaxis], verbose=0)

        # Handle different output formats (list vs dict)
        if isinstance(outs, list):
            pred_dict = dict(zip(keys, outs))
        elif isinstance(outs, dict):
            if keys and keys[0] in outs:
                pred_dict = outs
            else:
                out_names = list(getattr(model, "output_names", []))
                if out_names:
                    pred_dict = {k: outs[n] for k, n in zip(keys, out_names)}
                else:
                    pred_dict = dict(zip(keys, outs.values()))
        else:
            raise TypeError("Unsupported prediction return type")

        # Undo the input scaling so stems are at the original level.
        return {k: pred_dict[k][0, :TARGET_LEN, 0] * global_peak for k in keys}

    # Process audio with overlapping windows
    for start in range(0, n_samples, hop):
        pred = predict_chunk(wav[start:start + chunk_len])
        lo = start + offset
        hi = lo + TARGET_LEN
        for k in keys:
            acc[k][lo:hi] += pred[k] * window
        # Accumulate the weight once per chunk, using the same weight that was
        # applied to the stems, so the division below is a true weighted mean.
        wsum[lo:hi] += window

    # Normalize by overlap weights, trim to original length, save with human names
    eps = 1e-8
    norm = np.maximum(wsum[:n_samples], eps)
    for k in keys:
        stem = acc[k][:n_samples] / norm
        sf.write(
            Path(output_dir, f"{key_to_name[k]}.wav"),
            stem.astype(np.float32),
            sr,
            subtype="FLOAT",
        )



# ==============================================================================
# Example Usage
# ==============================================================================

# Uncomment to run separation on a test file:
# separate_long_audio(reloaded, "~/path/to/input.wav", "~/path/to/output/")



I0000 00:00:1761720305.966347  102393 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78761 MB memory:  -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:61:00.0, compute capability: 9.0
I0000 00:00:1761720305.968009  102393 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 78761 MB memory:  -> device: 1, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:62:00.0, compute capability: 9.0
I0000 00:00:1761720305.969526  102393 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 78761 MB memory:  -> device: 2, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:63:00.0, compute capability: 9.0
I0000 00:00:1761720305.971003  102393 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 78761 MB memory:  -> device: 3, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:64:00.0, compute capability: 9.0
I0000 00:00:1761720305.972473  102393 gpu_device.cc:2019] Create