In [11]:
# GPU setup (memory growth)
import os
import tensorflow as tf
# Best-effort: switch on memory growth for every GPU TensorFlow reports.
# Any failure is printed and the notebook carries on without it.
try:
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        print('No GPUs available')
    else:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f'Enabled memory growth on {len(gpus)} GPU(s)')
except Exception as e:
    print('Memory growth setup skipped:', e)


Enabled memory growth on 8 GPU(s)


In [12]:
# Imports and constants
import os, math
import numpy as np
import librosa, soundfile as sf
import tensorflow as tf

from demucs_v4_model import (
    ExpandDims, ReduceMean, LocalSelfAttention, STFT, InverseSTFT, custom_loss
)

SR = 44100            # default sample rate passed to the separation routine
PADDED_LEN = 441_000  # fallback chunk length when the model input length is not fixed (equals 10 * SR)

# Human-readable stem labels; the label at position i maps to model key 'instrument_{i+1}'.
INSTRUMENT_NAMES = [
    'Guitar', 'Drums', 'Piano', 'Bass', 'Strings (continued)',
    'Organ', 'Synth Lead', 'Synth Pad', 'Chromatic Percussion',
    'Brass', 'Pipe', 'Reed', 'Strings',
]

# Label -> 'instrument_N' (1-based), and the inverse lookup.
MODEL_KEYS = {
    label: f'instrument_{n}'
    for n, label in enumerate(INSTRUMENT_NAMES, start=1)
}
KEY_TO_NAME = dict(zip(MODEL_KEYS.values(), MODEL_KEYS.keys()))

def model_chunk_len(m):
    """Return the size of axis 1 of the model's first input as an int.

    Falls back to PADDED_LEN when the shape is empty/unknown or when that
    axis is undefined (None or 0).
    """
    input_shape = m.inputs[0].shape
    if not input_shape:
        return PADDED_LEN
    time_dim = input_shape[1]
    if not time_dim:
        return PADDED_LEN
    return int(time_dim)


In [13]:
# Load the saved Keras model for inference only (compile=False skips compilation,
# so no optimizer/loss state is restored).
# NOTE(review): this cell was labelled "single GPU/CPU", but its recorded output
# logs a MirroredStrategy spanning 8 GPUs — confirm where that strategy is created.
model = tf.keras.models.load_model(
    'demucs_v4_fixed_model.keras',
    # Custom layers/loss the saved model refers to by name.
    custom_objects=dict(
        ExpandDims=ExpandDims,
        ReduceMean=ReduceMean,
        LocalSelfAttention=LocalSelfAttention,
        STFT=STFT,
        InverseSTFT=InverseSTFT,
        custom_loss=custom_loss,
    ),
    compile=False,
)
# Display the model's output names as the cell result.
model.output_names


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3', '/job:localhost/replica:0/task:0/device:GPU:4', '/job:localhost/replica:0/task:0/device:GPU:5', '/job:localhost/replica:0/task:0/device:GPU:6', '/job:localhost/replica:0/task:0/device:GPU:7')
Using MirroredStrategy with 8 GPUs


ListWrapper(['conv1d_155', 'conv1d_173', 'conv1d_175', 'conv1d_177', 'conv1d_179', 'conv1d_157', 'conv1d_159', 'conv1d_161', 'conv1d_163', 'conv1d_165', 'conv1d_167', 'conv1d_169', 'conv1d_171'])

In [14]:
# Overlap-add separation using output keys from first predict (like Testing)
def separate_long_audio(model, audio_path, output_dir, sr=SR, batch_size=8):
    """Separate a long recording into stems with windowed overlap-add.

    The file is loaded, downmixed to mono, resampled to `sr` if needed, and cut
    into chunks of the model's input length with a hop of half a chunk. Each
    chunk is scaled to unit peak, run through `model.predict`, scaled back,
    weighted with a Hann window and summed into a per-stem accumulator. The
    accumulators are divided by the summed window weights and written as
    16-bit PCM WAV files, one per model output.

    Args:
        model: object exposing `predict(x, verbose=0)` and `inputs`; predict may
            return a dict, a list/tuple, or a single array.
        audio_path: input audio file; `~` is expanded.
        output_dir: directory for the stem files; created if missing, `~` is expanded.
        sr: sample rate the audio is resampled to and the stems are written at.
        batch_size: number of chunks sent to `model.predict` per call.

    Returns:
        None. Writes `<name>.wav` per stem into `output_dir` and prints a summary line.

    Raises:
        ValueError: if the input file contains no samples.
    """
    audio_path = os.path.expanduser(audio_path)
    output_dir = os.path.expanduser(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # Load audio. Downmix BEFORE resampling: sf.read returns (frames, channels)
    # for multi-channel files, while librosa.resample operates on the last axis
    # by default, so resampling the 2-D array would resample across channels.
    wav, file_sr = sf.read(audio_path, always_2d=False)
    if wav.ndim == 2:
        wav = wav.mean(axis=1)
    if file_sr != sr:
        wav = librosa.resample(wav, orig_sr=file_sr, target_sr=sr)
    wav = wav.astype(np.float32)

    N = len(wav)
    if N == 0:
        # Fail with a clear message instead of an opaque reduction error below.
        raise ValueError(f'No audio samples found in {audio_path}')

    # Model sizes
    CHUNK = model_chunk_len(model)
    HOP = CHUNK // 2
    WIN = np.hanning(CHUNK).astype(np.float32)

    def normalised_chunk(start):
        """Return (chunk scaled to unit peak and zero-padded to CHUNK, peak used)."""
        x = wav[start:start + CHUNK]
        peak = float(np.max(np.abs(x)))
        if peak < 1e-7:
            peak = 1.0  # (near-)silent chunk: avoid dividing by ~0
        if len(x) < CHUNK:
            x = np.pad(x, (0, CHUNK - len(x)))
        return (x / peak).astype(np.float32)[..., None], peak

    # Probe the first chunk once to learn how predict() structures its outputs.
    head_in, _ = normalised_chunk(0)
    test_out = model.predict(head_in[np.newaxis], verbose=0)
    use_dict = isinstance(test_out, dict)
    if use_dict:
        keys = list(test_out.keys())
    elif isinstance(test_out, (list, tuple)):
        keys = list(range(len(test_out)))
    else:
        # Single-output model: predict() returns one bare array, whose len()
        # would be the batch size rather than the number of stems.
        keys = [0]

    def unpack(out):
        """Return the per-stem arrays of one predict() result, ordered like `keys`."""
        if use_dict:
            return [out[k] for k in keys]
        if isinstance(out, (list, tuple)):
            return list(out)
        return [out]

    # One accumulator per stem plus the summed window weights. The extra CHUNK
    # samples of headroom let the last (padded) chunk be added without bounds checks.
    acc = {k: np.zeros(N + CHUNK, np.float32) for k in keys}
    wsum = np.zeros(N + CHUNK, np.float32)

    batch = []
    starts = []  # (start_idx, peak)

    def flush():
        """Run the pending chunks through the model and overlap-add the results."""
        if not batch:
            return
        stems = unpack(model.predict(np.stack(batch, axis=0), verbose=0))
        for b, (start, peak) in enumerate(starts):
            span = slice(start, start + CHUNK)
            for k, stem in zip(keys, stems):
                # Undo the per-chunk peak scaling, then apply the window.
                acc[k][span] += stem[b, :, 0].astype(np.float32) * peak * WIN
            wsum[span] += WIN
        batch.clear()
        starts.clear()

    for pos in range(0, N, HOP):
        xin, peak = normalised_chunk(pos)
        batch.append(xin)
        starts.append((pos, peak))
        if len(batch) == batch_size:
            flush()
    flush()

    # Normalise by the accumulated window weight; eps guards samples where the
    # window sum is ~0 (the Hann window is zero at its end points).
    eps = 1e-8
    norm = np.maximum(wsum[:N], eps)
    for k in keys:
        stem = (acc[k][:N] / norm).astype(np.float32)
        if use_dict:
            name = KEY_TO_NAME.get(k, str(k))
        else:
            # NOTE(review): positional labelling assumes the model's output order
            # matches INSTRUMENT_NAMES — confirm against the model definition.
            name = INSTRUMENT_NAMES[k] if k < len(INSTRUMENT_NAMES) else f"stem_{k+1}"
        sf.write(os.path.join(output_dir, f'{name}.wav'), stem, sr, subtype='PCM_16')
    print(f"Saved stems to {output_dir}")


In [15]:
# Run prediction
AUDIO_PATH = '~/madari3/1508 mix.wav'   # input mix ('~' is expanded by separate_long_audio)
OUTPUT_DIR = '~/madari3/output'         # stems are written here
BATCH_SIZE = 8  # chunks per predict() call; increase to use more GPU memory

# Runs the full separation when this cell executes.
separate_long_audio(
    model=model,
    audio_path=AUDIO_PATH,
    output_dir=OUTPUT_DIR,
    batch_size=BATCH_SIZE,
)


2025-09-30 04:29:48.579940: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]
	 [[RemoteCall]]


KeyError: 'conv1d_155'