## Setup

In [None]:
# Remove RAPIDS & TF families that cause dependency conflicts in Colab
!pip -q uninstall -y \
  cudf-cu12 numba cuml-cu12 cugraph-cu12 dask-cudf-cu12 dask-cuda rapids-dask-dependency \
  distributed-ucxx-cu12 ucx-py-cu12 rmm-cu12 cudf-cu12 cuvs-cu12 cupy-cuda12x raft-dask-cu12 nx-cugraph-cu12 \
  pylibcugraph-cu12 pylibraft-cu12 ucxx-cu12 \
  tensorflow tensorflow-text tensorflow-decision-forests tf-keras keras keras-hub tensorflow-hub \
  ddsp crepe || true

# Fresh packaging toolchain
!pip -q install -U pip setuptools wheel build jedi

In [None]:
# Fresh NumPy/SciPy
!pip -q install --no-cache-dir --force-reinstall "numpy==2.1.3" "scipy==1.14.1"

# TensorFlow stack that works on Py3.12 and satisfies DDSP/TFP imports
!pip -q install -U "tensorflow==2.20.0" "tf-keras==2.20.0" "tensorflow-probability==0.25.0"

# Silence CPU runtimes warning if present
!pip -q uninstall -y jax-cuda12-plugin -q || true

# numba compatible with NumPy 2.1.x (needed by librosa/resampy/etc.)
!pip -q install -U "numba==0.62.0"

# Audio + utils
!pip -q install -U librosa soundfile absl-py

# JAX + ecosystem
!pip -q install -U "jax[cpu]>=0.4.28" "flax>=0.8.2" "optax>=0.2.2" chex orbax-checkpoint gin-config clu


# Install DDSP source without its pinned (outdated) requirements
!pip -q install --no-deps "git+https://github.com/magenta/ddsp@main#egg=ddsp"

# f0 extractor (replacement for CREPE)
!pip -q install -U torch torchcrepe

In [1]:
import os
os.environ.setdefault("TF_GPU_ALLOCATOR", "cuda_malloc_async")  # helps fragmentation

import tensorflow as tf
# Let TF allocate GPU memory gradually
for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as e:
        print("Could not set memory growth:", e)

# Mixed precision halves activation memory on GPU
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")

print("Policy:", mixed_precision.global_policy())

Policy: <DTypePolicy "mixed_float16">


In [2]:
# Provide a tiny 'crepe' module so ddsp.losses can import; we won't call it.
import sys, types
_crepe = types.ModuleType("crepe")
def _predict_stub(*args, **kwargs):
    raise RuntimeError("ddsp.losses tried to call crepe.predict; use torchcrepe for f0 instead.")
_crepe.predict = _predict_stub
sys.modules["crepe"] = _crepe

# Smoke test
import numpy as np, librosa, soundfile, torch, torchcrepe
import tensorflow_probability as tfp
import jax, flax, ddsp

print("numpy:", np.__version__)
print("TF / TFP:", tf.__version__, "/", tfp.__version__)
print("jax / flax:", jax.__version__, "/", flax.__version__)
print("ddsp:", ddsp.__version__)

numpy: 2.1.3
TF / TFP: 2.20.0 / 0.25.0
jax / flax: 0.8.0 / 0.12.0
ddsp: 3.7.0


In [3]:
#@title Mount & paths
from google.colab import drive
from pathlib import Path
import yaml, json, math, time

drive.mount('/content/drive', force_remount=True)

PROJECT_DIR = Path('/content/drive/MyDrive/ddsp-demucs')
CFG = yaml.safe_load(open(PROJECT_DIR / 'env' / 'config.yaml'))

TFRECORDS_DIR = Path(CFG['paths']['tfrecords_dir'])
EXP_DIR = Path(CFG['paths']['exp_dir']) / 'run_ddsp_001'
EXP_DIR.mkdir(parents=True, exist_ok=True)

print("TFRecords:", TFRECORDS_DIR)
print("Exp:", EXP_DIR)

Mounted at /content/drive
TFRecords: /content/drive/MyDrive/ddsp-demucs/data/tfrecords
Exp: /content/drive/MyDrive/ddsp-demucs/exp/run_ddsp_001


In [24]:
import glob, os
for p in glob.glob("/content/cache/**/*.lockfile", recursive=True):
    try: os.remove(p)
    except FileNotFoundError: pass


## Training

In [5]:
#@title Globals
# === Globals ===
SR            = 22050
AUDIO_SAMPLES = 88200           # 4 s
TPRIME        = 1000            # conditioning frames
FRAME_RATE    = TPRIME * SR / AUDIO_SAMPLES  # 250.0 fps
HOP_SAMPLES   = int(round(SR / FRAME_RATE))  # 88
N_HARMONICS   = 64
N_NOISE_BANDS = 65
NFFT = 1024
HOP  = int(round(SR / FRAME_RATE))


In [6]:
#@title Torhcrepe-backed f0
import numpy as np, torch, torchcrepe, tensorflow as tf, ddsp
torch.set_num_threads(1)

def _torchcrepe_f0_np(y_np, sr, hop_length, fmin, fmax, periodicity_thresh, _ignored=None):
    # normalize scalars
    sr  = int(sr); hop_length = int(hop_length)
    fmin = float(fmin); fmax = float(fmax); periodicity_thresh = float(periodicity_thresh)

    # audio → [1, T] float32
    y_np = np.asarray(y_np, dtype=np.float32).reshape(-1)
    x = torch.from_numpy(y_np[None, :])  # [1, T]

    with torch.no_grad():
        f0, pd = torchcrepe.predict(
            x, sr, hop_length, fmin, fmax,
            model='tiny',                 # ← hardcoded; avoids bytes entirely
            batch_size=1024, device='cpu', return_periodicity=True,
        )

    f0 = f0.squeeze(0).cpu().numpy()
    pd = pd.squeeze(0).cpu().numpy()
    f0[pd < periodicity_thresh] = 0.0
    return f0.astype(np.float32, copy=False)

def _f0_from_torchcrepe_tf(y_1d, fmin=90.0, fmax=600.0,
                           hop_length=512, periodicity_thresh=0.40):
    f0 = tf.numpy_function(
        _torchcrepe_f0_np,
        [y_1d,
         tf.constant(SR, tf.int32),
         tf.constant(hop_length, tf.int32),
         tf.constant(fmin, tf.float32),
         tf.constant(fmax, tf.float32),
         tf.constant(periodicity_thresh, tf.float32),
         0],  # dummy
        Tout=tf.float32,
    )
    f0 = tf.reshape(f0, [-1])
    f0 = ddsp.core.resample(f0[tf.newaxis, :], TPRIME)[0]
    return tf.ensure_shape(f0, [TPRIME])

In [7]:
#@title Utils (DSP Helpers and Parser)
def _pad_crop_1d(x, length):
    x = tf.reshape(x, [-1])[:length]
    pad = tf.maximum(0, length - tf.shape(x)[0])
    x = tf.pad(x, [[0, pad]])
    x.set_shape([length])
    return x

def _resample_to_sr(x, sr_in, sr_out):
    x = tf.reshape(x, [-1])
    if sr_in == sr_out: return x
    ratio = tf.cast(sr_out, tf.float32) / tf.cast(sr_in, tf.float32)
    new_len = tf.cast(tf.round(tf.cast(tf.shape(x)[0], tf.float32) * ratio), tf.int32)
    return ddsp.core.resample(x[tf.newaxis, :], new_len)[0]

def _to_Tprime(vec_1d):
    vec_1d = tf.reshape(tf.cast(vec_1d, tf.float32), [-1])
    return ddsp.core.resample(vec_1d[tf.newaxis, :], TPRIME)[0]

def _rms_loudness_db(x, eps=1e-8):
    x = tf.reshape(x, [1, -1])
    hop = AUDIO_SAMPLES // TPRIME         # ~= 88
    win = max(256, hop * 2)
    stft = tf.signal.stft(x, frame_length=win, frame_step=hop,
                          window_fn=tf.signal.hann_window, pad_end=True)
    mag = tf.abs(stft)                     # [1, T*, F]
    rms = tf.sqrt(tf.reduce_mean(tf.square(mag), axis=-1) + eps)  # [1, T*]
    ld  = 20.0 * tf.math.log(rms + eps) / tf.math.log(10.0)
    return _to_Tprime(tf.squeeze(ld, 0))   # [T’]

def _logmel_1xTprime(x, n_fft=1024, mel_bins=64):
    x = tf.reshape(x, [1, -1])                              # [1, T]
    hop = AUDIO_SAMPLES // TPRIME                           # ~= 88
    S = tf.abs(tf.signal.stft(
        x, frame_length=n_fft, frame_step=hop,
        window_fn=tf.signal.hann_window, pad_end=True))     # [1, T*, F]  F = n_fft//2 + 1

    mel_fb = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=mel_bins,
        num_spectrogram_bins=n_fft // 2 + 1,
        sample_rate=SR,
        lower_edge_hertz=50.0,
        upper_edge_hertz=SR * 0.45,
    )                                                       # [F, mel_bins]

    M = tf.tensordot(S, mel_fb, axes=[[2], [0]])           # [1, T*, mel_bins]
    M = tf.math.log(M + 1e-5)                               # [1, T*, mel_bins]

    # resample time axis T* -> TPRIME
    M = tf.squeeze(M, 0)                                    # [T*, mel_bins]
    M = ddsp.core.resample(M[tf.newaxis, ...], TPRIME)[0]   # [TPRIME, mel_bins]
    return tf.ensure_shape(M, [TPRIME, mel_bins])           # [T’, mel]
              # [T’, mel]



def _parse_and_features(serialized):
    feats = {
        "audio/inputs":      tf.io.FixedLenFeature([], tf.string),
        "audio/targets":     tf.io.FixedLenFeature([], tf.string),
        "audio/sample_rate": tf.io.FixedLenFeature([], tf.int64),
        "meta/track":        tf.io.FixedLenFeature([], tf.string),
    }
    ex  = tf.io.parse_single_example(serialized, feats)
    xin = tf.io.decode_raw(ex["audio/inputs"], tf.float32)
    ygt = tf.io.decode_raw(ex["audio/targets"], tf.float32)
    sr  = tf.cast(ex["audio/sample_rate"], tf.int32)

    # resample → fixed 4 s window
    xin = _pad_crop_1d(_resample_to_sr(xin, sr, SR), AUDIO_SAMPLES)
    ygt = _pad_crop_1d(_resample_to_sr(ygt, sr, SR), AUDIO_SAMPLES)

    # features
    ld     = _rms_loudness_db(xin)                    # [T’]
    mel_in = _logmel_1xTprime(xin)                    # [T’, 64]
    mel_gt = _logmel_1xTprime(ygt)                    # [T’, 64]
    f0_in = _f0_from_torchcrepe_tf(
        xin, fmin=90.0, fmax=600.0, hop_length=512, periodicity_thresh=0.40
    )                                             # [T’]

    cond = {
        "x_in":        xin,
        "f0_in":       f0_in,
        "loudness_db": ld,
        "mel_in":      mel_in,
        "mel_gt":      mel_gt,
        "track":       ex["meta/track"],
    }
    return cond, ygt


In [8]:
#@title Copy files to Shard
SRC_TRAIN = "/content/drive/MyDrive/ddsp-demucs/data/tfrecords/train"
SRC_VAL   = "/content/drive/MyDrive/ddsp-demucs/data/tfrecords/val"

!mkdir -p /content/data/tfrecords/train /content/data/tfrecords/val
!rsync -ah --info=progress2 "{SRC_TRAIN}/" /content/data/tfrecords/train/
!rsync -ah --info=progress2 "{SRC_VAL}/"   /content/data/tfrecords/val/

import tensorflow as tf
train_files = tf.io.gfile.glob("/content/data/tfrecords/train/*.tfrecord")
val_files   = tf.io.gfile.glob("/content/data/tfrecords/val/*.tfrecord")
train_files.sort(); val_files.sort()
print(len(train_files), "train,", len(val_files), "val")
compression = ""   # <- no gzip


          4.81G 100%   42.53MB/s    0:01:47 (xfr#14, to-chk=0/15)
        553.43M 100%   12.28MB/s    0:00:42 (xfr#4, to-chk=0/5)
14 train, 4 val


In [9]:
#@title Make Cached Slice function
import os, tensorflow as tf

def make_cached_slice(files, K=400, compression="", cache_path=None):
    """
    Build a small, deterministic slice of the TFRecords and (optionally) cache it to disk.

    Args:
      files: list of TFRecord shard paths
      K:     take the first K records across shards (deterministic)
      compression: "" or "GZIP" to match how the TFRecords were written
      cache_path: str or None. If provided, use a single-file cache at this path.
                  Example: "/content/cache/train.tfcache"
    """
    # TFRecordDataset in deterministic single-threaded mode
    ds = tf.data.TFRecordDataset(files, compression_type=compression)

    # Take a prefix deterministically (before parsing)
    if K is not None:
        ds = ds.take(int(K))

    # Optional on-disk cache
    if cache_path:
        # If a directory was passed, create a filename inside it
        if cache_path.endswith("/") or tf.io.gfile.isdir(cache_path):
            tf.io.gfile.makedirs(cache_path)
            cache_file = os.path.join(cache_path, "cache.tfcache")
        else:
            # Ensure parent exists
            parent = os.path.dirname(cache_path)
            if parent:
                tf.io.gfile.makedirs(parent)
            cache_file = cache_path
        ds = ds.cache(cache_file)   # <- single file cache (creates a .lockfile next to it)
    else:
        ds = ds.cache()             # in-RAM cache

    # Parse → features (uses your existing function)
    ds = ds.map(_parse_and_features, num_parallel_calls=1, deterministic=True)

    # Final pipeline
    ds = ds.batch(1, drop_remainder=True).prefetch(1)
    return ds


In [26]:
#@title Model
# --- hard reset setup ---
import os, tensorflow as tf, tensorflow.keras as keras, ddsp
from tensorflow.keras import mixed_precision

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"   # no XLA while stabilizing
mixed_precision.set_global_policy("float32")         # keep dtypes simple

for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception:
        pass
tf.config.optimizer.set_jit(False)

# --- tiny harmonic-only model (same as before) ---
from ddsp.synths import Harmonic, FilteredNoise
from ddsp.effects import Reverb

N_HARMONICS   = 64
N_NOISE_BANDS = 65
MEL_BINS      = 64
LOUD_GATE_DB  = -60.0
MAX_HZ = 6000.0

class DDSPDecoder(keras.Model):
    def __init__(self, rnn_units=256, mlp_units=(256, 128), **kwargs):
        super().__init__(**kwargs)
        self.f0midi_range = (24.0, 84.0)  # clip to [C1, C6]

        # ⬇️ NEW: encodes [T’, 64] mel into [T’, 32]
        self.mel_enc = keras.Sequential([
            keras.layers.Dense(64, activation='relu'),
            keras.layers.Dense(32, activation='relu'),
        ])

        # feature trunk (Dense auto-infers input dim: 2 + 32)
        self.pre  = keras.layers.Dense(128, activation='relu')
        self.gru  = keras.layers.GRU(rnn_units, return_sequences=True)
        self.post = keras.Sequential(
            [keras.layers.Dense(u, activation='relu') for u in mlp_units]
        )

        self.amp_head   = keras.layers.Dense(1)
        self.harm_head  = keras.layers.Dense(N_HARMONICS)
        self.noise_head = keras.layers.Dense(N_NOISE_BANDS)

        self.reverb = ddsp.effects.Reverb(trainable=True)
        self.harm   = ddsp.synths.Harmonic(sample_rate=SR, amp_resample_method='linear')
        self.noise  = ddsp.synths.FilteredNoise(n_samples=AUDIO_SAMPLES,
                                               scale_fn=ddsp.core.exp_sigmoid, initial_bias=-5.0)


    def call(self, inputs, training=False):
      # small helper to accept multiple possible key names
      def _get(name_options, required=True):
          for k in name_options:
              if k in inputs:
                  return tf.cast(inputs[k], tf.float32)
          if required:
              raise KeyError(f"Missing any of keys {name_options} in inputs.")
          return None

      f0_hz      = _get(["f0_hz", "f0_in"])                 # [B, T’]
      k = tf.ones([5,1,1], tf.float32) / 5.0   # 5-frame box filter
      f0_s = tf.nn.conv1d(f0_hz[..., tf.newaxis], k, stride=1, padding="SAME")[...,0]
      f0_hz = tf.where(f0_hz > 0.0, f0_s, 0.0)  # keep zero on unvoiced
      loudness   = _get(["loudness_db", "ld_db", "loudness"])  # [B, T’]
      # prefer mel_in; else fall back to mel_gt (teacher forcing) or generic "mel"
      mel        = _get(["mel_in", "mel_gt", "mel"], required=False)  # [B, T’, 64] or None
      x_in       = _get(["x_in"], required=False)            # optional; may be None


      # (keep your existing rank guards)
      if f0_hz.shape.rank == 1: f0_hz = f0_hz[tf.newaxis, :]
      if loudness.shape.rank == 1: loudness = loudness[tf.newaxis, :]

      # if mel provided, ensure [B, T’, 64]
      if mel is not None and mel.shape.rank == 2:
          mel = mel[tf.newaxis, ...]

      if mel is not None:
          if mel.shape.rank == 2: mel = mel[tf.newaxis, ...]     # [1,T’,64]
          mel_feat = self.mel_enc(mel)                            # [B,T’,32]
      else:
          Tprime = tf.shape(f0_hz)[1]
          mel_feat = tf.zeros([tf.shape(f0_hz)[0], Tprime, 32], tf.float32)

      # features → trunk
      f0_midi = ddsp.core.hz_to_midi(tf.clip_by_value(f0_hz, 1.0, 8000.0))
      f0_midi = tf.clip_by_value(f0_midi, *self.f0midi_range)
      x = tf.stack([f0_midi, loudness], axis=-1)                     # [B,T’,2]
      x = tf.concat([x, mel_feat], axis=-1)                       # [B,T’,34]
      x = self.pre(x); x = self.gru(x); x = self.post(x)

      # heads
      amp   = ddsp.core.exp_sigmoid(self.amp_head(x))             # [B,T’,1]
      harmd = tf.nn.softmax(self.harm_head(x), axis=-1)           # [B,T’,H]

      noise = ddsp.core.exp_sigmoid(self.noise_head(x))           # [B,T’,BANDS]

      TP = tf.shape(f0_hz)[1]
      # After computing amp, harmd, noise, f0_hz
      amp   = tf.ensure_shape(amp,   [None, None, 1])
      harmd = tf.ensure_shape(harmd, [None, None, N_HARMONICS])
      noise = tf.ensure_shape(noise, [None, None, N_NOISE_BANDS])
      f0_hz = tf.ensure_shape(f0_hz, [None, None])

      voiced = tf.cast(f0_hz > 0.0, tf.float32)[..., tf.newaxis]  # [B,T’,1]
      amp = amp * voiced * 0.8  # zero out harmonic amp where unvoiced

      # harmd: [B,T’,H], f0_hz: [B,T’]
      H = tf.shape(harmd)[-1]
      idx = tf.linspace(1.0, tf.cast(H, tf.float32), H)  # [H]
      idx = idx[tf.newaxis, tf.newaxis, :]               # [1,1,H]
      harm_freq = f0_hz[..., tf.newaxis] * idx           # [B,T’,H]
      mask = tf.cast(harm_freq <= (SR/2.0 - 200.0), tf.float32)  # small safety margin
      harmd = harmd * mask



     # --- synthesis ---
      # 1) Harmonic synthesis (may return 64k depending on internal hop assumptions)
      f0_3d = f0_hz[..., tf.newaxis]        # [B, T', 1]
      audio_h = self.harm(
          amplitudes=amp,                   # keep [B, T', 1]
          harmonic_distribution=harmd,      # [B, T', H]
          f0_hz=f0_3d                       # [B, T', 1]
      )                                     # -> [B, Nh] (Nh could be 64000)

      # 2) Noise synthesis — first align magnitudes’ frame count to noise path
      hop_noise = tf.cast(tf.math.ceil(AUDIO_SAMPLES / TPRIME), tf.int32)  # e.g., 89
      t_noise   = tf.cast(tf.math.ceil(AUDIO_SAMPLES / hop_noise), tf.int32)  # e.g., 992
      noise_mags_match = ddsp.core.resample(noise, t_noise)
      audio_n = self.noise(magnitudes=noise_mags_match)                      # -> [B, Nn]

      # 3) Force both waveforms to the exact target length
      def _resample_to_len(x, n):
          # x: [B, N] or [B, N, 1]  →  return [B, n]
          if x.shape.rank == 3 and x.shape[-1] == 1:
              x = tf.squeeze(x, axis=-1)
          x = ddsp.core.resample(x, n)     # linear 1-D resample along time axis
          return x

      audio_h = _resample_to_len(audio_h, AUDIO_SAMPLES)    # [B, AUDIO_SAMPLES]
      audio_n = _resample_to_len(audio_n, AUDIO_SAMPLES)    # [B, AUDIO_SAMPLES]

      # 4) Mix + reverb
      audio = audio_h + audio_n                              # shapes now identical
      audio = self.reverb(audio)                             # [B, AUDIO_SAMPLES]
      return audio




from ddsp.losses import SpectralLoss
# ---- Loss helpers -----------------------------------------------------------

# Use a gentler spectral mix (reduce log-mag + deltas a bit)
spec_loss = SpectralLoss(
    fft_sizes=(2048, 1024, 512, 256, 128, 64),
    loss_type='L1',
    mag_weight=1.0,
    logmag_weight=0.2,     # was 1.0 → too aggressive, tends to “fizz”
    delta_freq_weight=0.2, # was 0.5
    delta_time_weight=0.05 # was 0.1
)

def _mel_filterbank(n_fft, mel_bins, sr):
    # Build once and reuse; shape [F, M]
    fb = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=mel_bins,
        num_spectrogram_bins=n_fft // 2 + 1,
        sample_rate=sr,
        lower_edge_hertz=50.0,
        upper_edge_hertz=sr * 0.45,
    )
    return tf.cast(fb, tf.float32)

# Precompute mel filterbank once, as a tf.constant
MEL_FB = tf.constant(
    tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=MEL_BINS,
        num_spectrogram_bins=NFFT // 2 + 1,
        sample_rate=SR,
        lower_edge_hertz=50.0,
        upper_edge_hertz=SR * 0.45,
    ),
    dtype=tf.float32
)

@tf.function
def mel_spec(y, n_fft=NFFT, hop=HOP, mel_fb=MEL_FB):
    # y: [B, T]
    y = tf.cast(y, tf.float32)
    S = tf.abs(tf.signal.stft(
        y, frame_length=n_fft, frame_step=hop, fft_length=n_fft,
        window_fn=tf.signal.hann_window, pad_end=True
    ))  # [B, Tm, F]
    # project to mel with einsum to avoid shape gotchas
    Y = tf.einsum('btf,fm->btm', S, mel_fb)  # [B, Tm, M]
    return tf.math.log(Y + 1e-5)

@tf.function
def mel_l1(y_true, y_pred, mel_fb=MEL_FB):
    Yt = mel_spec(y_true, mel_fb=mel_fb)
    Yp = mel_spec(y_pred, mel_fb=mel_fb)
    n = tf.minimum(tf.shape(Yt)[1], tf.shape(Yp)[1])
    return tf.reduce_mean(tf.abs(Yt[:, :n, :] - Yp[:, :n, :]))

@tf.function
def spec_centroid(y, n_fft=NFFT, hop=HOP):
    y = tf.cast(y, tf.float32)
    S = tf.abs(tf.signal.stft(
        y, frame_length=n_fft, frame_step=hop, fft_length=n_fft,
        window_fn=tf.signal.hann_window, pad_end=True
    ))  # [B,T,F]
    freqs = tf.linspace(0.0, tf.cast(SR, tf.float32)/2.0, n_fft//2 + 1)  # [F]
    num = tf.reduce_sum(S * freqs[tf.newaxis, tf.newaxis, :], axis=-1)   # [B,T]
    den = tf.reduce_sum(S + 1e-8, axis=-1)                               # [B,T]
    c   = num / (den + 1e-8)                                             # [B,T]
    return tf.reduce_mean(c, axis=-1)                                    # [B]

@tf.function
def centroid_l1(y_true, y_pred):
    ct = spec_centroid(y_true); cp = spec_centroid(y_pred)
    return tf.reduce_mean(tf.abs(ct - cp)) / (SR/2.0)

@tf.function
def si_sdr_loss(y_true, y_pred, eps=1e-8):
    # Scale-invariant SDR as a *loss* (lower is better), averaged over batch
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    # project y_pred onto y_true
    dot = tf.reduce_sum(y_true * y_pred, axis=-1, keepdims=True)
    den = tf.reduce_sum(y_true * y_true, axis=-1, keepdims=True) + eps
    s_target = dot / den * y_true
    e_noise  = y_pred - s_target
    num = tf.reduce_sum(s_target**2, axis=-1) + eps
    den = tf.reduce_sum(e_noise**2,  axis=-1) + eps
    si_sdr = 10.0 * tf.math.log(num / den) / tf.math.log(10.0)
    return -tf.reduce_mean(si_sdr)  # we minimize

# ---- Trainer ---------------------------------------------------------------

class DDSPTrainer(keras.Model):
    def __init__(self, net, mel_w=1.0, cent_w=0.05, sisdr_w=0.0):
        super().__init__()
        self.net = net
        self.mel_w, self.cent_w, self.sisdr_w = mel_w, cent_w, sisdr_w

        # Metrics: keep one Mean per scalar you want to report
        self.m_train_total = keras.metrics.Mean(name="loss")
        self.m_val_total   = keras.metrics.Mean(name="val_loss")
        self.m_spec        = keras.metrics.Mean(name="spec")
        self.m_mel         = keras.metrics.Mean(name="mel")
        self.m_cent        = keras.metrics.Mean(name="cent")
        self.m_sisdr       = keras.metrics.Mean(name="sisdr")

    @property
    def metrics(self):
        # Keras will reset these between epochs
        return [self.m_train_total, self.m_val_total, self.m_spec, self.m_mel, self.m_cent, self.m_sisdr]

    def build(self, _=None):
        self.built = True

    def compile(self, optimizer):
        super().compile()
        self.optimizer = optimizer

    @tf.function
    def _component_losses(self, y_t, y_p):
        Ls = spec_loss(y_t, y_p)        # scalar
        Lm = mel_l1(y_t, y_p)           # scalar
        Lc = centroid_l1(y_t, y_p)      # scalar
        Ld = si_sdr_loss(y_t, y_p) if self.sisdr_w > 0.0 else tf.constant(0.0, tf.float32)
        total = Ls + self.mel_w*Lm + self.cent_w*Lc + self.sisdr_w*Ld
        return Ls, Lm, Lc, Ld, total

    @tf.function
    def train_step(self, data):
        cond, target = data
        with tf.GradientTape() as tape:
            pred = self.net(cond, training=True)
            n = tf.minimum(tf.shape(pred)[1], tf.shape(target)[1])
            x_in = tf.cast(cond["x_in"][:, :tf.shape(pred)[1]], tf.float32)  # align to pred length
            y_t  = tf.cast(target[:,    :tf.shape(pred)[1]], tf.float32)
            y_p  = x_in + tf.cast(pred, tf.float32)  # predict correction
            # then pass y_t and y_p to your component losses
            Ls, Lm, Lc, Ld, total = self._component_losses(y_t, y_p)
        grads = tape.gradient(total, self.net.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.net.trainable_variables))
        if hasattr(self.net, "global_step"):
            self.net.global_step.assign_add(1)

        # Update metrics with SCALARS
        self.m_train_total.update_state(total)
        self.m_spec.update_state(Ls)
        self.m_mel.update_state(Lm)
        self.m_cent.update_state(Lc)
        self.m_sisdr.update_state(Ld)

        # What shows in keras logs per batch/epoch
        return {
            "loss": self.m_train_total.result(),
            "spec": self.m_spec.result(),
            "mel":  self.m_mel.result(),
            "cent": self.m_cent.result(),
            "sisdr": self.m_sisdr.result(),
        }

    @tf.function
    def test_step(self, data):
        cond, target = data
        pred = self.net(cond, training=False)
        n = tf.minimum(tf.shape(pred)[1], tf.shape(target)[1])
        x_in = tf.cast(cond["x_in"][:, :tf.shape(pred)[1]], tf.float32)  # align to pred length
        y_t  = tf.cast(target[:,    :tf.shape(pred)[1]], tf.float32)
        y_p  = x_in + tf.cast(pred, tf.float32)  # predict correction
        # then pass y_t and y_p to your component losses
        Ls, Lm, Lc, Ld, total = self._component_losses(y_t, y_p)

        # Validation metrics
        self.m_val_total.update_state(total)
        self.m_spec.update_state(Ls)
        self.m_mel.update_state(Lm)
        self.m_cent.update_state(Lc)
        self.m_sisdr.update_state(Ld)

        # Keras expects 'val_loss' to exist for callbacks/monitoring
        return {
            "val_loss": self.m_val_total.result(),
            "spec": self.m_spec.result(),
            "mel":  self.m_mel.result(),
            "cent": self.m_cent.result(),
            "sisdr": self.m_sisdr.result(),
        }

In [27]:
#@title Building the trainer
model   = DDSPDecoder()
trainer = DDSPTrainer(model, mel_w=1.0, cent_w=0.03, sisdr_w=0.0)
opt = keras.optimizers.Adam(learning_rate=1e-4, clipnorm=1.0)


def cosine_decay(e):
    # 1 epoch warmup → cosine to 3e-5
    base, minlr = 3e-4, 3e-5
    if e < 1: return base * (0.2 + 0.8 * e)   # warmup to ~3e-4
    t = (e-1)/max(1, (10-1))                  # if you train ~10 epochs
    return minlr + 0.5*(base-minlr)*(1+tf.cos(np.pi*tf.clip_by_value(t,0.0,1.0)))

lr_cb = keras.callbacks.LearningRateScheduler(lambda e, lr: float(cosine_decay(e)))

# dev slice (fast)
train_run = make_cached_slice(train_files, K=800, compression="",
                              cache_path="/content/cache/train/")
val_run   = make_cached_slice(val_files,   K=240, compression="",
                              cache_path="/content/cache/val/")

STEPS_PER_EPOCH = 400
VAL_STEPS       = 120
EPOCHS          = 10

# build
# one real batch
cond0, y0 = next(iter(train_run.take(1)))
# call the net (not trainer) with the fields it expects
_ = model({
    "f0_hz":      cond0["f0_in"],
    "loudness_db":cond0["loudness_db"],
    "mel_in":     cond0["mel_in"],
}, training=False)

# mark the trainer as built (no variables of its own)
trainer.built = True


from pathlib import Path

# --- Run folders (LOCAL for speed) ---
RUN_NAME  = "run_ddsp_001"
BASE_DIR  = Path("/content/exp") / RUN_NAME
CKPT_DIR  = BASE_DIR / "ckpt"
LOG_DIR   = BASE_DIR / "tb"
CKPT_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)

ckpt_path = str(CKPT_DIR / "ddsp.weights.h5")  # for SaveNetWeights

class SaveNetWeights(keras.callbacks.Callback):
    def __init__(self, path, monitor="val_loss", mode="min"):
        super().__init__()
        self.path = path
        self.monitor = monitor
        self.mode = mode
        self.best = float("inf") if mode == "min" else -float("inf")

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        val = logs.get(self.monitor)
        if val is None:
            return  # no validation this epoch
        improve = (val < self.best) if self.mode == "min" else (val > self.best)
        if improve:
            self.best = val
            # save underlying net, not the trainer
            self.model.net.save_weights(self.path)
            print(f"\nSaved best net weights to {self.path} (best {self.monitor}: {val:.4f})")

callbacks = [
    keras.callbacks.ModelCheckpoint(ckpt_path, save_weights_only=True,
                                    monitor="val_val_loss", mode="min",
                                    save_best_only=True),
    keras.callbacks.EarlyStopping(monitor="val_val_loss", mode="min",
                                  patience=4, restore_best_weights=True),
]


trainer.compile(optimizer=opt)

In [28]:
#@title Fitting
EPOCHS = 3
history = trainer.fit(
    train_run,
    validation_data=val_run,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VAL_STEPS,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)

Epoch 1/3
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m473s[0m 1s/step - cent: 0.0737 - loss: 5.5507 - mel: 1.6696 - sisdr: 0.0000e+00 - spec: 3.8789 - val_cent: 0.0628 - val_mel: 1.2020 - val_sisdr: 0.0000e+00 - val_spec: 3.3172 - val_val_loss: 4.5211
Epoch 2/3
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m448s[0m 1s/step - cent: 0.0964 - loss: 6.2612 - mel: 2.0084 - sisdr: 0.0000e+00 - spec: 4.2499 - val_cent: 0.0704 - val_mel: 1.2298 - val_sisdr: 0.0000e+00 - val_spec: 3.3624 - val_val_loss: 4.5943
Epoch 3/3




[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m104s[0m 260ms/step - cent: 0.0964 - loss: 6.2612 - mel: 2.0084 - sisdr: 0.0000e+00 - spec: 4.2499 - val_cent: 0.0735 - val_mel: 1.2281 - val_sisdr: 0.0000e+00 - val_spec: 3.3572 - val_val_loss: 4.5875


In [29]:
#@title Objective Evaluation
import numpy as np
import tensorflow as tf
from IPython.display import Audio, display

def si_sdr(ref, est):
    # robust to tiny DC, use float64 for dot products
    ref = ref.astype(np.float64) - np.mean(ref)
    est = est.astype(np.float64) - np.mean(est)
    a = np.dot(est, ref) / (np.dot(ref, ref) + 1e-12)
    e_true = a * ref
    e_res  = est - e_true
    return 10.0 * np.log10((np.sum(e_true**2) + 1e-12) / (np.sum(e_res**2) + 1e-12))

def spec_conv(a, b, n_fft=1024, hop=256):
    A = np.abs(tf.signal.stft(a, n_fft, hop, n_fft).numpy())
    B = np.abs(tf.signal.stft(b, n_fft, hop, n_fft).numpy())
    return np.linalg.norm(A - B) / (np.linalg.norm(A) + 1e-12)

def _np1d(t):  # (B,T) / (T,) tensor -> (T,) np
    x = t.numpy()
    return np.asarray(x).reshape(-1)

sdrs, scs = [], []
listen = []  # keep a couple samples to audition

for k, (cond, tgt) in enumerate(val_run.take(16)):
    # forward
    pred = model({
        "f0_hz":       cond["f0_in"],
        "loudness_db": cond["loudness_db"],
        "mel_in":      cond["mel_in"],
        "x_in":        cond["x_in"],
    }, training=False)

    # tensors -> 1D numpy
    y   = _np1d(tgt)
    p   = _np1d(pred)
    xin = _np1d(cond["x_in"][:, :tf.shape(pred)[1]])

    # residual output = input + correction
    yhat = xin + p

    # align (safety)
    n = min(y.shape[0], yhat.shape[0])
    y   = y[:n]
    yhat= yhat[:n]

    sdrs.append(si_sdr(y, yhat))
    scs.append(spec_conv(y, yhat))

    if k < 3:  # keep a few to listen
        listen.append(("pred",  yhat, 22050))
        listen.append(("input", xin[:n], 22050))
        listen.append(("target",y, 22050))

print(f"Val SI-SDR median: {np.median(sdrs):.2f} dB")
print(f"Val SpectralConv median: {np.median(scs):.3f}")

# quick audition
for tag, w, sr in listen:
    print(tag, "rms=", float(np.sqrt(np.mean(w*w))+1e-12))
    display(Audio(w, rate=sr))

Val SI-SDR median: 3.23 dB
Val SpectralConv median: 0.535
pred rms= 0.05987080931663513


input rms= 0.05203075334429741


target rms= 0.052850883454084396


pred rms= 0.06370053440332413


input rms= 0.05486690253019333


target rms= 0.056081682443618774


pred rms= 0.06168686971068382


input rms= 0.05262365564703941


target rms= 0.0539819672703743


In [19]:
#@title ID Sample
from IPython.display import Audio, display
import numpy as np

def play_idx(ds, idx, model, sr=SR):
    # get the idx-th item
    for i, (cond, tgt) in enumerate(ds.skip(idx).take(1)):
        pred = model({
          "f0_hz": cond["f0_in"],
          "loudness_db": cond["loudness_db"],
          "mel": cond["mel_in"],
          "x_in": cond["x_in"],
        }, training=False)
        y = tgt[0].numpy().astype(np.float32)
        p = cond["x_in"][:, :tf.shape(pred)[1]] + pred
        display(Audio(y, rate=sr))
        display(Audio(p, rate=sr))
        break

play_idx(val_run, idx=81, model=model)     # specific

## Freeze

In [31]:
RUN_TAG   = "ddsp_vocals_residual_v1"
EXPORT_DIR = f"/content/exports/{RUN_TAG}"
SR         = 22050
N_EVAL     = 48  # how many val examples to score/export

import os, json, random, numpy as np, tensorflow as tf
os.makedirs(EXPORT_DIR, exist_ok=True)

cfg = dict(
    run_tag=RUN_TAG,
    sr=SR,
    n_eval=N_EVAL,
    model_params=int(sum(np.prod(v.shape) for v in model.trainable_variables)),
)
with open(os.path.join(EXPORT_DIR, "config.json"), "w") as f:
    json.dump(cfg, f, indent=2)
print(cfg)

{'run_tag': 'ddsp_vocals_residual_v1', 'sr': 22050, 'n_eval': 48, 'model_params': 470626}


In [32]:
#@title Metrics Helpers
import numpy as np
import tensorflow as tf

def si_sdr(ref, est):
    ref = ref.astype(np.float64) - np.mean(ref)
    est = est.astype(np.float64) - np.mean(est)
    den = np.dot(ref, ref) + 1e-12
    a   = np.dot(est, ref) / den
    e_true = a * ref
    e_res  = est - e_true
    return 10.0*np.log10((np.sum(e_true**2)+1e-12)/(np.sum(e_res**2)+1e-12))

def spec_conv(a, b, n_fft=1024, hop=256):
    A = np.abs(tf.signal.stft(a, n_fft, hop, n_fft).numpy())
    B = np.abs(tf.signal.stft(b, n_fft, hop, n_fft).numpy())
    return np.linalg.norm(A - B) / (np.linalg.norm(A) + 1e-12)

def np1d(t):
    x = t.numpy()
    return np.asarray(x).reshape(-1)

def voiced_pct(f0):
    return float(np.mean((np.asarray(f0).reshape(-1) > 0.0))) * 100.0


In [33]:
#@title Evaluation
import csv, math, os
from scipy.io import wavfile
from IPython.display import Audio, display

csv_path = os.path.join(EXPORT_DIR, "eval_summary.csv")
fields = ["idx","track","voiced_pct","rms_in","rms_pred","rms_tgt",
          "si_sdr_baseline","si_sdr_model","spec_conv_baseline","spec_conv_model",
          "len_samples"]

rows = []
sdrs_model, scs_model = [], []
sdrs_base,  scs_base  = [], []

it = iter(val_run.take(N_EVAL))
for k in range(N_EVAL):
    try:
        cond, tgt = next(it)
    except StopIteration:
        break

    # forward
    pred = model({
        "f0_hz":       cond["f0_in"],
        "loudness_db": cond["loudness_db"],
        "mel_in":      cond["mel_in"],
        "x_in":        cond["x_in"],
    }, training=False)

    # tensors -> numpy
    y   = np1d(tgt).astype(np.float32)
    p   = np1d(pred).astype(np.float32)
    xin = np1d(cond["x_in"][:, :tf.shape(pred)[1]]).astype(np.float32)

    # residual output
    yhat = xin + p

    # align
    n = min(len(y), len(yhat), len(xin))
    y, yhat, xin = y[:n], yhat[:n], xin[:n]

    # metrics
    sdr_b = si_sdr(y, xin)
    sdr_m = si_sdr(y, yhat)
    sc_b  = spec_conv(y, xin)
    sc_m  = spec_conv(y, yhat)

    sdrs_base.append(sdr_b);   sdrs_model.append(sdr_m)
    scs_base.append(sc_b);     scs_model.append(sc_m)

    # exports
    trk = cond["track"].numpy()[0].decode("utf-8") if "track" in cond else f"val_{k}"
    stem = f"{k:03d}_{trk.replace(' ','_')[:40]}"
    for tag, audio in [("input",xin),("pred",yhat),("target",y)]:
        wavfile.write(os.path.join(EXPORT_DIR, f"{stem}_{tag}.wav"), SR,
                      np.clip(audio, -1.0, 1.0))

    rows.append(dict(
        idx=k, track=trk, voiced_pct=voiced_pct(cond["f0_in"]),
        rms_in=float(np.sqrt(np.mean(xin*xin))+1e-12),
        rms_pred=float(np.sqrt(np.mean(yhat*yhat))+1e-12),
        rms_tgt=float(np.sqrt(np.mean(y*y))+1e-12),
        si_sdr_baseline=float(sdr_b), si_sdr_model=float(sdr_m),
        spec_conv_baseline=float(sc_b), spec_conv_model=float(sc_m),
        len_samples=int(n),
    ))

with open(csv_path, "w", newline="") as f:
    w = csv.DictWriter(f, fieldnames=fields)
    w.writeheader(); w.writerows(rows)

print(f"wrote {len(rows)} rows → {csv_path}")
print(f"Median SI-SDR baseline: {np.median(sdrs_base):.2f} dB")
print(f"Median SI-SDR model:    {np.median(sdrs_model):.2f} dB")
print(f"Median SpecConv base:   {np.median(scs_base):.3f}")
print(f"Median SpecConv model:  {np.median(scs_model):.3f}")

# listen to one triplet
if rows:
    ex = rows[min(3, len(rows)-1)]
    base = os.path.join(EXPORT_DIR, f"{ex['idx']:03d}_{ex['track'].replace(' ','_')[:40]}")
    print("Preview:", ex['track'])
    display(Audio(os.path.join(EXPORT_DIR, base + "_input.wav"),  rate=SR))
    display(Audio(os.path.join(EXPORT_DIR, base + "_pred.wav"),   rate=SR))
    display(Audio(os.path.join(EXPORT_DIR, base + "_target.wav"), rate=SR))


wrote 48 rows → /content/exports/ddsp_vocals_residual_v1/eval_summary.csv
Median SI-SDR baseline: 10.65 dB
Median SI-SDR model:    4.99 dB
Median SpecConv base:   0.209
Median SpecConv model:  0.435
Preview: Auctioneer - Our Future Faces


In [34]:
#@title Quick Ablation
def eval_with_overrides(n_eval=16, drop_mel=False, drop_f0=False):
    sdrs = []
    for k, (cond, tgt) in enumerate(val_run.take(n_eval)):
        feed = {
            "f0_hz":       (tf.zeros_like(cond["f0_in"]) if drop_f0 else cond["f0_in"]),
            "loudness_db": cond["loudness_db"],
            "mel_in":      (tf.zeros_like(cond["mel_in"]) if drop_mel else cond["mel_in"]),
            "x_in":        cond["x_in"],
        }
        pred = model(feed, training=False)
        y = np1d(tgt).astype(np.float32)
        p = np1d(pred).astype(np.float32)
        xin = np1d(cond["x_in"][:, :tf.shape(pred)[1]]).astype(np.float32)
        yhat = (xin + p)[:min(len(y), len(p), len(xin))]
        y    = y[:len(yhat)]
        sdrs.append(si_sdr(y, yhat))
    return float(np.median(sdrs)) if sdrs else float("nan")

print("Median SI-SDR (eval, normal):", eval_with_overrides(24, False, False))
print("Median SI-SDR (no mel):      ", eval_with_overrides(24, True,  False))
print("Median SI-SDR (no f0):       ", eval_with_overrides(24, False, True))


Median SI-SDR (eval, normal): 3.4834501839464327
Median SI-SDR (no mel):       3.603918420346659
Median SI-SDR (no f0):        11.372660720393917
