In [2]:
#@title Mount Drive & load config
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

from pathlib import Path
import yaml, json, os, sys, time

# Project layout on Drive. The config file is expected to have been created by
# 00_setup_env.ipynb (see the assertion message below).
PROJECT_DIR = Path('/content/drive/MyDrive/ddsp-demucs')
CONFIG_PATH = PROJECT_DIR / 'env' / 'config.yaml'
assert CONFIG_PATH.exists(), f"Missing config at {CONFIG_PATH}. Run 00_setup_env.ipynb first."

# safe_load only builds plain Python objects (no arbitrary object construction).
with open(CONFIG_PATH, encoding="utf-8") as f:
    CFG = yaml.safe_load(f)

# Anchor config paths at PROJECT_DIR. Joining an absolute path onto PROJECT_DIR
# yields that absolute path unchanged, so absolute config values behave exactly
# as before, while relative values (e.g. data/stems/demucs_htdemucs44k) no
# longer resolve against the kernel's working directory.
MUSDB_ROOT = PROJECT_DIR / CFG['dataset']['root']
STEMS_DIR  = PROJECT_DIR / CFG['paths']['stems_dir']  # e.g., data/stems/demucs_htdemucs44k
STEMS_DIR.mkdir(parents=True, exist_ok=True)

print("Project:", PROJECT_DIR)
print("MUSDB root:", MUSDB_ROOT)
print("Stems out:", STEMS_DIR)


Mounted at /content/drive
Project: /content/drive/MyDrive/ddsp-demucs
MUSDB root: /content/drive/MyDrive/ddsp-demucs/data/musdb18hq
Stems out: /content/drive/MyDrive/ddsp-demucs/data/stems/demucs_htdemucs44k


In [3]:
#@title Install/verify libs (Demucs, musdb, audio I/O)
!pip -q install musdb stempeg museval demucs torchmetrics librosa soundfile -U

import torch, torchaudio, musdb, demucs
print("Torch:", torch.__version__, "| CUDA:", torch.cuda.is_available())
print("Demucs:", demucs.__version__)

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━[0m [32m0.8/1.2 MB[0m [31m23.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m23.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.1/87.1 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.6/59.6 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
#@title Configure inference
from dataclasses import dataclass

@dataclass
class InferenceConfig:
    """Tunable settings for the Demucs separation run in this notebook.

    Defaults are evaluated once, when this class body executes: `device`
    reflects CUDA availability at that moment and `manifest_csv` is built from
    the value of STEMS_DIR at that moment. Re-run this cell after changing
    either of those.
    """
    model_name: str = "htdemucs"    # pretrained model name passed to get_model; or "htdemucs_ft" / "demucs48_hq" etc.
    subsets: tuple = ("train", "test")  # MUSDB subsets to process: ("train",) or ("test",)
    save_accompaniment: bool = False    # also write accompaniment.stereo.wav (mixture minus vocals)
    write_mono: bool = True         # saves vocals.mono.wav (+ accompaniment.mono.wav if save_accompaniment is enabled)
    segment_seconds: int = 0        # 0 = full track; >0 = chunk into N-sec segments (lower VRAM)
    overlap: float = 0.1            # Demucs overlap for apply_model
    split: bool = True              # Demucs time-chunking (helps long tracks)
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # device the model and inputs are moved to
    sample_rate_out: int | None = None   # None = keep original; else resample after separation
    resume: bool = True             # skip tracks whose outputs already exist (see already_done)
    write_metadata_json: bool = True    # write demucs_metadata.json next to each track's stems
    manifest_csv: Path = STEMS_DIR / "manifest.csv"  # appended progressively, one row per processed track

# Single shared config instance read by the cells below; the bare name on the
# last line displays its repr as the cell output.
INF = InferenceConfig()
INF


InferenceConfig(model_name='htdemucs', subsets=('train', 'test'), save_accompaniment=False, write_mono=True, segment_seconds=0, overlap=0.1, split=True, device='cuda', sample_rate_out=None, resume=True, write_metadata_json=True, manifest_csv=PosixPath('/content/drive/MyDrive/ddsp-demucs/data/stems/demucs_htdemucs44k/manifest.csv'))

In [5]:
#@title Helpers (downmix, resample, save, manifest)
import numpy as np
import pandas as pd
import soundfile as sf

def downmix_mono(x: torch.Tensor) -> torch.Tensor:
    """Collapse a multi-channel waveform to one channel by averaging.

    Args:
        x: waveform laid out as (C, T).

    Returns:
        The per-sample mean over the channel axis, shaped (1, T).
    """
    mono = torch.mean(x, dim=0, keepdim=True)
    return mono

def ensure_sr(x: torch.Tensor, sr_src: int, sr_dst: int) -> tuple[torch.Tensor, int]:
    """Resample `x` to `sr_dst` when a different target rate is requested.

    Args:
        x: waveform tensor to resample.
        sr_src: sampling rate of `x`.
        sr_dst: requested sampling rate; None means "keep the source rate".

    Returns:
        (waveform, sampling_rate). When no resampling is needed the input
        tensor itself is returned together with `sr_src`.
    """
    needs_resample = sr_dst is not None and sr_dst != sr_src
    if needs_resample:
        resampled = torchaudio.functional.resample(x, sr_src, sr_dst)
        return resampled, sr_dst
    return x, sr_src

def safe_save_wav(path, wav: torch.Tensor, sr: int):
    """Write a waveform tensor to `path` as a 16-bit PCM WAV file.

    Args:
        path: destination path object; its parent directory is created if missing.
        wav: waveform tensor, channels first.
        sr: sampling rate to store in the file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Detach from autograd and move to host memory before handing over to NumPy.
    samples = wav.detach().cpu().numpy()
    # The array is transposed to (T, C) for writing.
    sf.write(str(path), samples.T, sr, subtype="PCM_16")

def write_json(path: Path, data: dict):
    """Serialise `data` as indented JSON and write it to `path`.

    The payload is serialised *before* the file is opened, so a value that
    cannot be encoded raises without creating or truncating the target file
    (streaming straight into an opened file would leave a partial document).

    Args:
        path: destination file; its parent directory is created if missing.
        data: JSON-serialisable mapping.

    Raises:
        TypeError: if `data` contains a value json cannot encode.
    """
    payload = json.dumps(data, indent=2)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(payload)

def append_manifest(row: dict, manifest_path: Path):
    """Append one row to the CSV manifest, writing the header when needed.

    The header (the keys of `row`) is written when the manifest does not exist
    yet or exists but is empty; otherwise only the values are appended, so all
    rows must use the same keys in the same order as the first one.

    Args:
        row: column name -> value for a single manifest line.
        manifest_path: CSV file to append to; its parent directory is created
            if missing.
    """
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    # A zero-byte file has no header yet, so treat it like a missing file.
    needs_header = (not manifest_path.exists()) or manifest_path.stat().st_size == 0
    pd.DataFrame([row]).to_csv(manifest_path, mode="a", header=needs_header, index=False)


In [None]:
#@title Load pretrained Demucs
from demucs.pretrained import get_model

# Target device comes from the inference config.
device = INF.device
# Fetch the pretrained model by name, move it to the target device, and switch
# it to evaluation mode.
model = get_model(INF.model_name)
model = model.to(device)
model = model.eval()
print(f"Loaded model: {INF.model_name} on {device}")


Downloading: "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/955717e8-8726e21a.th" to /root/.cache/torch/hub/checkpoints/955717e8-8726e21a.th


100%|██████████| 80.2M/80.2M [00:00<00:00, 196MB/s]


Loaded model: htdemucs on cuda


In [None]:
#@title Run inference over MUSDB

from demucs.apply import apply_model
from tqdm import tqdm

def np_to_torch_wave(np_audio):
    """Convert a NumPy waveform to a channels-first float32 torch tensor.

    musdb provides audio as a (T, C) float array in [-1, 1], while Demucs
    expects a torch tensor laid out as (C, T). A 1-D input is treated as a
    single channel.

    Returns:
        A contiguous float32 tensor of shape (C, T) that does not share memory
        with the input array.
    """
    frames_by_channel = np_audio[:, None] if np_audio.ndim == 1 else np_audio
    # astype always copies, so the returned tensor never aliases the input.
    as_float32 = frames_by_channel.astype('float32')
    channels_first = torch.from_numpy(as_float32).permute(1, 0)  # (T, C) -> (C, T)
    return channels_first.contiguous()

def already_done(track_out_dir: Path, required_files=None) -> bool:
    """Tell whether a track's outputs are complete, so resume can skip it.

    Checking only the first file written would treat a run interrupted midway
    through a track as finished. Every expected file must therefore exist and
    be non-empty.

    Args:
        track_out_dir: per-track output directory.
        required_files: file names (relative to `track_out_dir`) that must be
            present. When None, the list is derived from the current INF
            settings: vocals.stereo.wav always, vocals.mono.wav when
            INF.write_mono, and the accompaniment files when
            INF.save_accompaniment.

    Returns:
        True when every required file exists with a size greater than zero.
    """
    if required_files is None:
        required_files = ["vocals.stereo.wav"]
        if INF.write_mono:
            required_files.append("vocals.mono.wav")
        if INF.save_accompaniment:
            required_files.append("accompaniment.stereo.wav")
            if INF.write_mono:
                required_files.append("accompaniment.mono.wav")

    def _present(name) -> bool:
        # A zero-byte file is what a write that failed right after open leaves behind.
        candidate = track_out_dir / name
        return candidate.is_file() and candidate.stat().st_size > 0

    return all(_present(name) for name in required_files)

# Look the vocals stem up by name instead of hard-coding its position, so the
# right source is saved even for a model whose source order differs.
VOCAL_INDEX = model.sources.index("vocals")

db = musdb.DB(
    root=str(MUSDB_ROOT),
    subsets=list(INF.subsets),
    is_wav=(CFG['dataset']['kind'] == 'hq')  # True for HQ WAV
)

# No manual "header" row is written here: append_manifest emits the CSV header
# itself on the first write. Appending a row of placeholder strings would leave
# a bogus data line (name, split, sr, ...) at the top of the manifest.

total_tracks = len(db.tracks)
print("Total tracks to consider:", total_tracks)

errors = []          # (track name, repr(exception)) for every failed track
processed = 0        # counts both newly separated and resume-skipped tracks
start_all = time.time()

for track in tqdm(db.tracks, desc="Separating"):
    out_dir = STEMS_DIR / track.name
    if INF.resume and already_done(out_dir):
        processed += 1
        continue

    try:
        # --- Load mixture from musdb object ---
        # track.audio: (T, C)
        mix_np = track.audio                  # NumPy float32/64
        sr = int(track.rate)                  # sampling rate
        mix = np_to_torch_wave(mix_np)        # torch (C, T), C=2 for MUSDB

        # --- Optional chunking for memory ---
        if INF.segment_seconds and INF.segment_seconds > 0:
            # Separate fixed-length pieces independently and concatenate them
            # along time; pieces are cut back-to-back with no overlap.
            seg_len = INF.segment_seconds * sr
            pieces = []
            for start in range(0, mix.shape[-1], seg_len):
                end = min(mix.shape[-1], start + seg_len)
                chunk = mix[:, start:end]
                with torch.inference_mode():
                    out = apply_model(
                        model,
                        chunk.unsqueeze(0).to(device),  # (1, C, Tseg)
                        split=INF.split,
                        overlap=INF.overlap
                    )[0].cpu()  # (nsrc, C, Tseg)
                pieces.append(out)
            sources = torch.cat(pieces, dim=-1)  # (nsrc, C, T)
        else:
            with torch.inference_mode():
                sources = apply_model(
                    model,
                    mix.unsqueeze(0).to(device),       # (1, C, T)
                    split=INF.split,
                    overlap=INF.overlap
                )[0].cpu()                              # (nsrc, C, T)

        vocals = sources[VOCAL_INDEX]                   # (C, T)
        if INF.save_accompaniment:
            # Either residual:
            accomp = mix - vocals
            # Or sum of other sources:
            # accomp = sources[[i for i in range(sources.size(0)) if i != VOCAL_INDEX]].sum(0)

        # --- Resample if requested ---
        vocals, sr_out = ensure_sr(vocals, sr, INF.sample_rate_out)
        if INF.save_accompaniment:
            accomp, _ = ensure_sr(accomp, sr, sr_out)

        # --- Save files ---
        out_dir.mkdir(parents=True, exist_ok=True)
        safe_save_wav(out_dir / "vocals.stereo.wav", vocals, sr_out)
        if INF.write_mono:
            safe_save_wav(out_dir / "vocals.mono.wav", downmix_mono(vocals), sr_out)

        if INF.save_accompaniment:
            safe_save_wav(out_dir / "accompaniment.stereo.wav", accomp, sr_out)
            if INF.write_mono:
                safe_save_wav(out_dir / "accompaniment.mono.wav", downmix_mono(accomp), sr_out)

        # --- Metadata & manifest ---
        # Duration is measured on the input mixture at its original rate.
        duration_s = float(mix.shape[-1]) / sr
        meta = {
            "track": track.name,
            "subset": track.subset,
            "sr_in": sr,
            "sr_out": sr_out,
            "duration_s": duration_s,
            "model": INF.model_name,
            "params": {
                "split": INF.split,
                "overlap": INF.overlap,
                "segment_seconds": INF.segment_seconds
            },
            # Paths in the metadata are stored relative to PROJECT_DIR.
            "paths": {
                "vocals_stereo": str((out_dir / "vocals.stereo.wav").relative_to(PROJECT_DIR)),
                "vocals_mono": str((out_dir / "vocals.mono.wav").relative_to(PROJECT_DIR)) if INF.write_mono else None,
                "acc_stereo": str((out_dir / "accompaniment.stereo.wav").relative_to(PROJECT_DIR)) if INF.save_accompaniment else None,
                "acc_mono": str((out_dir / "accompaniment.mono.wav").relative_to(PROJECT_DIR)) if (INF.save_accompaniment and INF.write_mono) else None,
            }
        }
        if INF.write_metadata_json:
            write_json(out_dir / "demucs_metadata.json", meta)

        # The manifest points at the mono file when one is written, otherwise
        # at the stereo file. time_s is the time elapsed since the start of
        # the whole run, not the time spent on this track alone.
        append_manifest({
            "track": track.name,
            "subset": track.subset,
            "sr_in": sr,
            "sr_out": sr_out,
            "duration_s": round(duration_s, 3),
            "vocals_path": str(out_dir / ("vocals.mono.wav" if INF.write_mono else "vocals.stereo.wav")),
            "acc_path": str(out_dir / ("accompaniment.mono.wav" if (INF.save_accompaniment and INF.write_mono) else "accompaniment.stereo.wav")) if INF.save_accompaniment else "",
            "time_s": round(time.time() - start_all, 2),
            "model": INF.model_name
        }, INF.manifest_csv)

        processed += 1

    except Exception as e:
        # Best effort: record the failure and continue with the next track.
        errors.append((track.name, repr(e)))
        print(f"\n⚠️ Error on {track.name}: {e}")

elapsed = time.time() - start_all
print(f"\n✅ Done. Processed {processed}/{total_tracks} tracks in {elapsed/60:.1f} min.")
if errors:
    print(f"Tracks with errors: {len(errors)}")
    # Only the first few failures are listed to keep the output short.
    for t, msg in errors[:8]:
        print(" -", t, ":", msg)


Total tracks to consider: 150


Separating: 100%|██████████| 150/150 [31:12<00:00, 12.49s/it]


✅ Done. Processed 150/150 tracks in 31.2 min.





In [11]:
#@title Inspect one separated track
import random, soundfile as sf, numpy as np
from IPython.display import Audio

# Pick a random track directory that exists under STEMS_DIR.
candidates = sorted(p for p in STEMS_DIR.glob("*") if p.is_dir())
assert candidates, f"No stems under {STEMS_DIR}. Did inference run?"
pick = random.choice(candidates)
vmono = pick / "vocals.mono.wav"
vst   = pick / "vocals.stereo.wav"

print("Sample track:", pick.name)
print("Exists mono:", vmono.exists(), "| stereo:", vst.exists())

# Preview window, in seconds. The excerpt starts 40 s into the track rather
# than at the beginning.
# NOTE(review): presumably chosen to skip intros without vocals — confirm.
PREVIEW_START_S = 40
PREVIEW_LEN_S = 10

if vmono.exists():
    y, sr = sf.read(str(vmono))
    y = y.squeeze() if y.ndim > 1 else y
    print("Mono shape:", y.shape, "| sr:", sr)
    if len(y) == 0:
        print("Track is empty; nothing to preview.")
    else:
        # Clamp the window so tracks shorter than start + length still yield a
        # non-empty excerpt (their last PREVIEW_LEN_S seconds, or the whole
        # track); an empty slice would make Audio fail.
        n_preview = PREVIEW_LEN_S * sr
        start = min(PREVIEW_START_S * sr, max(len(y) - n_preview, 0))
        display(Audio(y[start:start + n_preview], rate=sr))  # 10-second preview


Sample track: Invisible Familiars - Disturbing Wildlife
Exists mono: True | stereo: True
Mono shape: (9644414,) | sr: 44100
