In [3]:
#@title Mount Drive & load config
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

from pathlib import Path
import yaml, json

PROJECT_DIR = Path('/content/drive/MyDrive/ddsp-demucs')
# `with` closes the config file right away instead of leaking the handle.
with open(PROJECT_DIR / 'env' / 'config.yaml') as cfg_file:
    CFG = yaml.safe_load(cfg_file)

# Config paths may be absolute or relative (e.g. "data/stems/..."). Relative entries are
# anchored at PROJECT_DIR instead of the kernel's cwd; `PROJECT_DIR / <absolute path>`
# simply yields the absolute path, so absolute entries are unaffected.
STEMS_DIR      = PROJECT_DIR / CFG['paths']['stems_dir']        # e.g., data/stems/demucs_htdemucs44k
FEATURES_DIR   = PROJECT_DIR / CFG['paths']['features_dir']     # e.g., data/features
TFRECORDS_DIR  = PROJECT_DIR / CFG['paths']['tfrecords_dir']    # e.g., data/tfrecords
TFRECORDS_DIR.mkdir(parents=True, exist_ok=True)

MUSDB_ROOT = PROJECT_DIR / CFG['dataset']['root']               # symlink to musdb18_hq/
print("Project:", PROJECT_DIR)
print("Stems:", STEMS_DIR)
print("Features:", FEATURES_DIR)
print("TFRecords out:", TFRECORDS_DIR)
print("MUSDB root:", MUSDB_ROOT)

Mounted at /content/drive
Project: /content/drive/MyDrive/ddsp-demucs
Stems: /content/drive/MyDrive/ddsp-demucs/data/stems/demucs_htdemucs44k
Features: /content/drive/MyDrive/ddsp-demucs/data/features
TFRecords out: /content/drive/MyDrive/ddsp-demucs/data/tfrecords
MUSDB root: /content/drive/MyDrive/ddsp-demucs/data/musdb18hq


In [None]:
#@title TF setup (Lite 2.20) — uninstall extras that pin TF 2.19
import sys, subprocess, re, os
from importlib import metadata

TF_TARGET = "2.20.0"

def sh(cmd, check=False):
    """Echo and run a shell command; return its combined stdout/stderr text.

    With ``check=True`` a non-zero exit status raises ``RuntimeError`` (carrying the
    captured output), so a failed install is not silently followed by a restart.
    """
    print(">", cmd)
    proc = subprocess.run(cmd, shell=True, check=False, text=True,
                          stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if check and proc.returncode != 0:
        raise RuntimeError(f"Command failed (exit {proc.returncode}): {cmd}\n{proc.stdout}")
    return proc.stdout

def _dist_version(dist):
    """Installed version string of distribution `dist`, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None

# Guard: once TF_TARGET is installed (i.e. after the restart below), this cell is a
# no-op, so "Run all" no longer kills the runtime on every pass.
if _dist_version("tensorflow") == TF_TARGET:
    print(f"tensorflow {TF_TARGET} already installed; skipping reinstall and restart.")
else:
    # Use the kernel's own interpreter so pip targets the environment this notebook runs in.
    pip = f'"{sys.executable}" -m pip'
    # Remove extras that force TF 2.19 (note: dopamine-rl, if installed, requires tf-keras)
    print(sh(f"{pip} -q uninstall -y tensorflow-decision-forests tensorflow-text tf-keras"))
    # Install TF_TARGET; abort (instead of restarting) if pip fails
    print(sh(f'{pip} -q install -U "tensorflow=={TF_TARGET}"', check=True))
    print("Installed tensorflow:", _dist_version("tensorflow"), "- restarting runtime...")
    # Hard-restart so the next `import tensorflow` loads the freshly installed binaries;
    # TF libraries already loaded into this process cannot be swapped in place.
    os.kill(os.getpid(), 9)


> pip -q uninstall -y tensorflow-decision-forests tensorflow-text tf-keras

> pip -q install -U "tensorflow==2.20.0"
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 620.7/620.7 MB 2.1 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.5/5.5 MB 112.5 MB/s eta 0:00:00
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
dopamine-rl 4.1.2 requires tf-keras>=2.18.0, which is not installed.



In [1]:
#@title Minimal deps after TF setup
!pip -q install musdb stempeg soundfile librosa tqdm -U

import tensorflow as tf
import numpy as np, pandas as pd, soundfile as sf, librosa, musdb
from tqdm import tqdm
from dataclasses import dataclass
from pathlib import Path

print("TF:", tf.__version__)

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/963.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m962.6/963.0 kB[0m [31m36.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m963.0/963.0 kB[0m [31m25.0 MB/s[0m eta [36m0:00:00[0m
[?25hTF: 2.20.0


In [4]:
#@title Configure TFRecord building
@dataclass
class BuildCfg:
    """Settings for slicing accepted vocal segments into fixed-length TFRecord examples.

    Values are validated in ``__post_init__`` so a bad override fails here, not midway
    through the long TFRecord build.
    """
    # which accepted-segment list to consume (alternative: "accepted_segments_quota_detail.csv")
    accepted_csv: Path = FEATURES_DIR / "accepted_segments_quick.csv"
    # final training sample rate (Hz); Demucs inputs and GT targets are both resampled to it
    train_sr: int = 22050
    # window length / hop in seconds used to cut each accepted region into many examples
    win_s: float = 4.0
    hop_s: float = 1.0
    # number of examples per TFRecord shard
    examples_per_shard: int = 512
    # split strategy: MUSDB test tracks -> "test"; MUSDB train tracks are split into
    # train/val *by track*, holding out this fraction of them as val
    val_ratio: float = 0.1
    # loudness gate: a window is skipped when the RMS level in dB (20*log10 of the window
    # RMS, approximate) of either the input or the target is below this
    min_rms_db: float = -50.0

    def __post_init__(self):
        # accept plain strings for the CSV path as well as Path objects
        self.accepted_csv = Path(self.accepted_csv)
        if self.train_sr <= 0:
            raise ValueError(f"train_sr must be positive, got {self.train_sr}")
        # the build converts seconds to round(seconds * train_sr) samples; a zero hop
        # would make range() raise and a zero-length window is meaningless
        if round(self.win_s * self.train_sr) < 1 or round(self.hop_s * self.train_sr) < 1:
            raise ValueError(
                f"win_s/hop_s must each span >= 1 sample at {self.train_sr} Hz, "
                f"got win_s={self.win_s}, hop_s={self.hop_s}")
        if self.examples_per_shard < 1:
            raise ValueError(f"examples_per_shard must be >= 1, got {self.examples_per_shard}")
        if not 0.0 <= self.val_ratio <= 1.0:
            raise ValueError(f"val_ratio must be in [0, 1], got {self.val_ratio}")

BC = BuildCfg()
print("Using accepted segments from:", BC.accepted_csv)

Using accepted segments from: /content/drive/MyDrive/ddsp-demucs/data/features/accepted_segments_quick.csv


In [5]:
#@title Load accepted segments & MUSDB track map
# accepted segments CSV must contain: track,start_s,end_s,mono_fraction,segment_pass
segs = pd.read_csv(BC.accepted_csv)
assert len(segs), f"No segments in {BC.accepted_csv}"
# Fail fast on a schema mismatch (mono_fraction is not used below, so it is not required).
_missing_cols = {"track", "start_s", "end_s", "segment_pass"} - set(segs.columns)
assert not _missing_cols, f"{BC.accepted_csv} lacks columns: {sorted(_missing_cols)}"
# keep only rows whose segment_pass is exactly True (False / NaN rows are dropped)
segs = segs[segs["segment_pass"].eq(True)].copy()
segs["dur_s"] = segs["end_s"] - segs["start_s"]
segs = segs[segs["dur_s"] > 0.1]
print("Accepted segments:", len(segs), "| Unique tracks:", segs.track.nunique())

# MUSDB map: name -> Track object (its subset, rate and stems are used downstream)
db = musdb.DB(root=str(MUSDB_ROOT), subsets=['train','test'], is_wav=True)
name2track = {t.name: t for t in db.tracks}
subset_map = {t.name: t.subset for t in db.tracks}
print("MUSDB tracks seen:", len(name2track))

# Drop tracks MUSDB does not know: they have no GT vocals and would otherwise only
# fail later, in the middle of the TFRecord build.
unknown = sorted(set(segs.track.unique()) - set(name2track))
if unknown:
    print("⚠️ Tracks not in MUSDB:", unknown[:5], "... total", len(unknown))
    segs = segs[~segs.track.isin(unknown)]

# Verify stems folders exist for all
missing = [t for t in segs.track.unique() if not (STEMS_DIR / t / "vocals.mono.wav").exists()]
if missing:
    print("⚠️ Missing Demucs mono for tracks:", missing[:5], "... total", len(missing))
    segs = segs[~segs.track.isin(missing)]
    print("Filtered to:", len(segs), "segments")

assert len(segs), "No segments left after filtering"


Accepted segments: 6361 | Unique tracks: 149
MUSDB tracks seen: 150


In [6]:
#@title Helpers (audio IO, slicing, TF Example)
def read_wav_mono(path: Path):
    """Read an audio file as float32 and down-mix multi-channel audio to mono.

    Channels are averaged with equal weight. Returns ``(samples, sample_rate)``.
    """
    audio, rate = sf.read(str(path), dtype='float32')
    mono = audio.mean(axis=1) if audio.ndim > 1 else audio
    return mono, rate

def slice_sec(x, sr, start_s, end_s):
    """Return ``x`` between ``start_s`` and ``end_s`` seconds at sample rate ``sr``.

    Both bounds are rounded to the nearest sample and clamped to ``[0, len(x)]``.
    An empty or inverted range yields a single float32 zero sample (not an empty array).
    """
    n_samples = len(x)
    lo, hi = (min(max(int(round(t * sr)), 0), n_samples) for t in (start_s, end_s))
    if lo >= hi:
        return np.zeros(1, dtype=np.float32)
    return x[lo:hi]

def resample_if_needed(y, sr, tgt):
    """Resample ``y`` from ``sr`` to ``tgt`` Hz with librosa; return ``(audio, rate)``.

    When the rates already match, the input array itself is returned untouched.
    """
    if sr != tgt:
        y = librosa.resample(y, orig_sr=sr, target_sr=tgt)
        sr = tgt
    return y, sr

def rms_db(x, eps=1e-8):
    """Approximate RMS level of ``x`` in dB: ``20*log10(sqrt(mean(x**2) + eps) + eps)``.

    ``eps`` keeps digital silence finite (all-zero input gives about -80 dB).
    An empty input returns ``-inf`` rather than NaN. NaN would never compare below
    a loudness threshold, so an empty window would slip through a silence gate.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        return float("-inf")
    rms = np.sqrt(np.mean(x**2) + eps)
    return 20*np.log10(rms + eps)

def make_example(x_in, x_tgt, sr, track, start_s, end_s, subset):
    """Pack one aligned (Demucs input, ground-truth target) audio pair into a tf.train.Example.

    Audio is stored as raw float32 bytes (``ndarray.tobytes()``, the host's native byte
    order). Readers presumably decode with ``tf.io.decode_raw(..., tf.float32)``; TODO
    confirm against the training input pipeline. ``audio/length`` is the sample count
    shared by inputs and targets.

    Args:
        x_in, x_tgt: equal-length audio windows (converted to float32).
        sr: sample rate of both arrays, in Hz.
        track, subset: MUSDB track name and split label, stored as encoded bytes.
        start_s, end_s: window position within the track, in seconds.

    Raises:
        ValueError: if ``x_in`` and ``x_tgt`` differ in shape, because the single
            ``audio/length`` field could not describe both.
    """
    x_in = np.asarray(x_in, dtype=np.float32)
    x_tgt = np.asarray(x_tgt, dtype=np.float32)
    if x_in.shape != x_tgt.shape:
        raise ValueError(
            f"input/target shape mismatch for {track} "
            f"[{float(start_s):.3f}s, {float(end_s):.3f}s]: {x_in.shape} vs {x_tgt.shape}")

    def _bytes(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))

    def _float(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=[float(value)]))

    feat = {
        "audio/inputs": _bytes(x_in.tobytes()),
        "audio/targets": _bytes(x_tgt.tobytes()),
        "audio/sample_rate": _int(sr),
        "audio/length": _int(len(x_in)),
        "meta/track": _bytes(track.encode()),
        "meta/subset": _bytes(subset.encode()),
        "meta/start_sec": _float(start_s),
        "meta/end_sec": _float(end_s),
        "meta/source": _bytes(b"demucs_htdemucs->ddsp"),
    }
    return tf.train.Example(features=tf.train.Features(feature=feat))


In [8]:
#@title Build splits by track (stable)
rng = np.random.default_rng(1337)

tracks = sorted(segs.track.unique())
# Fixed assignment: MUSDB "test" tracks stay "test"; MUSDB "train" tracks (and tracks
# missing from subset_map) are divided into train/val by *track*, so no song spans splits.
val_candidates = list(filter(lambda name: subset_map.get(name, "train") == "train", tracks))
rng.shuffle(val_candidates)
n_val = int(round(BC.val_ratio * len(val_candidates)))
val_set = set(val_candidates[:n_val])

def split_of(track):
    """Map a track name to "test", "val" or "train".

    MUSDB test tracks are always "test". Any other track (names absent from
    `subset_map` count as MUSDB train) is "val" iff it was drawn into `val_set`.
    """
    if subset_map.get(track, "train") != "test":
        return "val" if track in val_set else "train"
    return "test"

segs["split"] = segs["track"].map(split_of)
segs["split"].value_counts()


Unnamed: 0_level_0,count
split,Unnamed: 1_level_1
train,3628
test,2301
val,432


In [9]:
#@title Write sharded TFRecords (train/val/test)
def shard_writer(prefix: Path, split: str, max_per_shard: int):
    """Create a rolling TFRecord writer that starts a new shard every `max_per_shard` examples.

    Shards are named ``prefix/{split}-{NNNNN}.tfrecord``. Existing files matching that
    pattern are deleted first, so a rebuild never mixes in stale shards from an earlier
    run (e.g. one that produced more shards or used a different split). Shards are opened
    lazily on the first write after a rollover. This means no empty shard is left behind
    when the total is an exact multiple of `max_per_shard` or nothing is written at all.

    Returns:
        (write, close): ``write(example)`` appends a serialized ``tf.train.Example``;
        ``close()`` closes the current shard and is safe to call more than once.

    Raises:
        ValueError: if `max_per_shard` is smaller than 1.
    """
    if max_per_shard < 1:
        raise ValueError(f"max_per_shard must be >= 1, got {max_per_shard}")
    prefix = Path(prefix)
    prefix.mkdir(parents=True, exist_ok=True)
    for stale in prefix.glob(f"{split}-*.tfrecord"):
        stale.unlink()

    writer = None
    count = 0
    shard_idx = 0

    def write(example):
        nonlocal writer, count, shard_idx
        if writer is None:
            shard_path = prefix / f"{split}-{shard_idx:05d}.tfrecord"
            writer = tf.io.TFRecordWriter(str(shard_path))
            shard_idx += 1
            count = 0
        writer.write(example.SerializeToString())
        count += 1
        if count >= max_per_shard:
            # shard is full: close it now; the next write opens a fresh one
            writer.close()
            writer = None

    def close():
        nonlocal writer
        if writer is not None:
            writer.close()
            writer = None

    return write, close

# One rolling shard writer per split. Train shards hold BC.examples_per_shard examples;
# the smaller val/test splits use half that, but never fewer than 64.
writers = {
    split_name: shard_writer(TFRECORDS_DIR / split_name, split_name, shard_size)
    for split_name, shard_size in (
        ("train", BC.examples_per_shard),
        ("val", max(64, BC.examples_per_shard // 2)),
        ("test", max(64, BC.examples_per_shard // 2)),
    )
}

# Cache the GT vocals of the most recently requested track only. The build loop groups
# segments by track, so each track is requested for one stretch and never again;
# keeping every entry would hold all tracks' full-length GT audio in RAM for the run.
gt_cache = {}  # name -> (y_mono, sr); holds at most one track

def get_gt_vocals(track_name: str):
    """Return ``(vocals_mono_float32, sample_rate)`` for a MUSDB track's GT vocals.

    The MUSDB vocals target is down-mixed to mono by averaging channels when it is 2-D.

    Raises:
        KeyError: if `track_name` is not in `name2track`.
    """
    cached = gt_cache.get(track_name)
    if cached is not None:
        return cached
    mt = name2track.get(track_name)
    if mt is None:
        raise KeyError(f"Track not found in MUSDB: {track_name}")
    # gt vocals: presumably (T, C) float from musdb; averaged to mono when 2-D
    y = mt.targets['vocals'].audio
    if y.ndim == 2:
        y = y.mean(axis=1)
    gt_cache.clear()  # evict the previous track before storing this one
    gt_cache[track_name] = (y.astype(np.float32), int(mt.rate))
    return gt_cache[track_name]

# build
examples_total = {"train":0, "val":0, "test":0}
# segments whose Demucs input and GT target lengths disagreed (both trimmed to the shorter)
bad = 0

# Window geometry depends only on the config, so compute it once.
win = int(round(BC.win_s * BC.train_sr))
hop = int(round(BC.hop_s * BC.train_sr))

try:
    for track_name, group in tqdm(segs.groupby("track"), desc="Writing TFRecords by track"):
        split_name = split_of(track_name)
        write, _close = writers[split_name]
        # Load Demucs mono and GT vocals once per track
        x_in, sr_in = read_wav_mono(STEMS_DIR / track_name / "vocals.mono.wav")
        x_gt, sr_gt = get_gt_vocals(track_name)

        # Process each accepted segment -> window into fixed examples
        for start_s, end_s in group[["start_s", "end_s"]].itertuples(index=False):
            start_s, end_s = float(start_s), float(end_s)
            seg_in = slice_sec(x_in, sr_in, start_s, end_s)
            seg_gt = slice_sec(x_gt, sr_gt, start_s, end_s)

            # Resample both to training SR
            seg_in, _ = resample_if_needed(seg_in, sr_in, BC.train_sr)
            seg_gt, _ = resample_if_needed(seg_gt, sr_gt, BC.train_sr)

            # Input and target can differ in length: slice_sec clamps at the end of the
            # shorter file, and different source rates round differently on resampling.
            # Trim both to the common length so every window is a true aligned pair.
            N = min(len(seg_in), len(seg_gt))
            if len(seg_in) != len(seg_gt):
                bad += 1
                seg_in, seg_gt = seg_in[:N], seg_gt[:N]
            if N < win:
                continue  # too short for even one window

            for a in range(0, N - win + 1, hop):
                b = a + win
                xin, xgt = seg_in[a:b], seg_gt[a:b]
                # quick loudness sanity (skip near-silence)
                if rms_db(xin) < BC.min_rms_db or rms_db(xgt) < BC.min_rms_db:
                    continue
                write(make_example(xin, xgt, BC.train_sr, track_name,
                                   start_s + a / BC.train_sr, start_s + b / BC.train_sr,
                                   split_name))
                examples_total[split_name] += 1
finally:
    # Always close the shards, even if the build fails or is interrupted midway.
    for _write, close in writers.values():
        close()

print("✅ Done.")
print("Examples written:", examples_total)
if bad:
    print("⚠️ Segments trimmed due to input/target length mismatch:", bad)


Writing TFRecords by track: 100%|██████████| 149/149 [16:04<00:00,  6.47s/it]


✅ Done.
Examples written: {'train': 6811, 'val': 784, 'test': 4154}
