## Setup

In [4]:
# Remove RAPIDS & TF families that cause dependency conflicts in Colab
!pip -q uninstall -y \
  cudf-cu12 numba cuml-cu12 cugraph-cu12 dask-cudf-cu12 dask-cuda rapids-dask-dependency \
  distributed-ucxx-cu12 ucx-py-cu12 rmm-cu12 cudf-cu12 cuvs-cu12 cupy-cuda12x raft-dask-cu12 nx-cugraph-cu12 \
  pylibcugraph-cu12 pylibraft-cu12 ucxx-cu12 \
  tensorflow tensorflow-text tensorflow-decision-forests tf-keras keras keras-hub tensorflow-hub \
  ddsp crepe || true

# Fresh packaging toolchain
!pip -q install -U pip setuptools wheel build jedi

[0m

In [1]:
# Fresh NumPy/SciPy
!pip -q install --no-cache-dir --force-reinstall "numpy==2.1.3" "scipy==1.14.1"

# TensorFlow stack that works on Py3.12 and satisfies DDSP/TFP imports
!pip -q install -U "tensorflow==2.20.0" "tf-keras==2.20.0" "tensorflow-probability==0.25.0"

# Silence CPU runtimes warning if present
!pip -q uninstall -y jax-cuda12-plugin -q || true

# numba compatible with NumPy 2.1.x (needed by librosa/resampy/etc.)
!pip -q install -U "numba==0.62.0"

# Audio + utils
!pip -q install -U librosa soundfile absl-py

# JAX + ecosystem
!pip -q install -U "jax[cpu]>=0.4.28" "flax>=0.8.2" "optax>=0.2.2" chex orbax-checkpoint gin-config clu


# Install DDSP source without its pinned (outdated) requirements
!pip -q install --no-deps "git+https://github.com/magenta/ddsp@main#egg=ddsp"

# f0 extractor (replacement for CREPE)
!pip -q install -U torch torchcrepe

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
resampy 0.4.3 requires numba>=0.53, which is not installed.
dopamine-rl 4.1.2 requires tensorflow>=2.2.0, which is not installed.
dopamine-rl 4.1.2 requires tf-keras>=2.18.0, which is not installed.
stumpy 1.13.0 requires numba>=0.57.1, which is not installed.
umap-learn 0.5.9.post2 requires numba>=0.51.2, which is not installed.
librosa 0.11.0 requires numba>=0.51.0, which is not installed.
pynndescent 0.5.13 requires numba>=0.51.2, which is not installed.
shap 0.49.1 requires numba>=0.54, which is not installed.[0m[31m
Reason for being yanked: <none given>[0m[33m
[0m  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building whe

In [1]:
import os
os.environ.setdefault("TF_GPU_ALLOCATOR", "cuda_malloc_async")  # helps fragmentation

import tensorflow as tf
# Let TF allocate GPU memory gradually
for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as e:
        print("Could not set memory growth:", e)

# Mixed precision halves activation memory on GPU
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")

print("Policy:", mixed_precision.global_policy())

Policy: <DTypePolicy "mixed_float16">


In [2]:
# Provide a tiny 'crepe' module so ddsp.losses can import; we won't call it.
import sys, types
_crepe = types.ModuleType("crepe")
def _predict_stub(*args, **kwargs):
    raise RuntimeError("ddsp.losses tried to call crepe.predict; use torchcrepe for f0 instead.")
_crepe.predict = _predict_stub
sys.modules["crepe"] = _crepe

# Smoke test
import numpy as np, librosa, soundfile, torch, torchcrepe
import tensorflow_probability as tfp
import jax, flax, ddsp

print("numpy:", np.__version__)
print("TF / TFP:", tf.__version__, "/", tfp.__version__)
print("jax / flax:", jax.__version__, "/", flax.__version__)
print("ddsp:", ddsp.__version__)

numpy: 2.1.3
TF / TFP: 2.20.0 / 0.25.0
jax / flax: 0.8.0 / 0.12.0
ddsp: 3.7.0


In [3]:
#@title Mount & paths
from google.colab import drive
from pathlib import Path
import yaml, json, math, time

drive.mount('/content/drive', force_remount=True)

PROJECT_DIR = Path('/content/drive/MyDrive/ddsp-demucs')
CFG = yaml.safe_load(open(PROJECT_DIR / 'env' / 'config.yaml'))

TFRECORDS_DIR = Path(CFG['paths']['tfrecords_dir'])
EXP_DIR = Path(CFG['paths']['exp_dir']) / 'run_ddsp_001'
EXP_DIR.mkdir(parents=True, exist_ok=True)

print("TFRecords:", TFRECORDS_DIR)
print("Exp:", EXP_DIR)

Mounted at /content/drive
TFRecords: /content/drive/MyDrive/ddsp-demucs/data/tfrecords
Exp: /content/drive/MyDrive/ddsp-demucs/exp/run_ddsp_001


## Training

In [4]:
#@title Training Files
# === Locate TFRecords & define train_files / val_files / compression ===
from pathlib import Path
import yaml

# Guess your project root (adjust if needed)
PROJECT_DIR = Path("/content/drive/MyDrive/ddsp-demucs")

def find_tfrecords_dir():
    cfg_path = PROJECT_DIR / "env" / "config.yaml"
    if cfg_path.exists():
        try:
            CFG = yaml.safe_load(open(cfg_path))
            return Path(CFG["paths"]["tfrecords_dir"])
        except Exception as e:
            print("Config read failed, falling back to default:", e)
    # Fallback used earlier in this notebook
    return PROJECT_DIR / "data" / "tfrecords"

TFRECORDS_DIR = find_tfrecords_dir()

def pick_split_dir(base, *names):
    for name in names:
        d = base / name
        if d.exists():
            return d
    return None

train_dir = pick_split_dir(TFRECORDS_DIR, "train", "training")
val_dir   = pick_split_dir(TFRECORDS_DIR, "val", "valid", "validation", "dev")

def list_shards(d):
    if d is None: return []
    patterns = ["*.tfrecord", "*.tfrecords", "*.tfrecord.gz", "*.tfrecords.gz"]
    files = []
    for pat in patterns:
        files += sorted(str(p) for p in d.glob(pat))
    return files

train_files = list_shards(train_dir)
val_files   = list_shards(val_dir)

def infer_compression(files_a, files_b):
    files = files_a if len(files_a) else files_b
    return "GZIP" if any(f.endswith(".gz") for f in files) else ""

compression = infer_compression(train_files, val_files)

print("TFRECORDS_DIR:", TFRECORDS_DIR)
print("train_dir:", train_dir)
print("val_dir:", val_dir)
print(f"Found {len(train_files)} train shards, {len(val_files)} val shards, compression='{compression}'")
if train_files:
    print("Example train shard:", train_files[0])


TFRECORDS_DIR: /content/drive/MyDrive/ddsp-demucs/data/tfrecords
train_dir: /content/drive/MyDrive/ddsp-demucs/data/tfrecords/train
val_dir: /content/drive/MyDrive/ddsp-demucs/data/tfrecords/val
Found 14 train shards, 4 val shards, compression=''
Example train shard: /content/drive/MyDrive/ddsp-demucs/data/tfrecords/train/train-00000.tfrecord


In [None]:
#@title Torhcrepe-backed f0
# ==== Torchcrepe-backed f0 (fixed for tf.py_function) ====
import numpy as np
import tensorflow as tf


import tensorflow as tf, ddsp, numpy as np

# Training-time globals (pick what you want)
SR         = 16000          # your model SR
FRAME_RATE = 250
WIN_S      = 4.0            # matches the TFRecord writer's BuildCfg.win_s
AUDIO_SAMPLES = int(SR * WIN_S)
TPRIME     = int(FRAME_RATE * WIN_S)
compression = globals().get("compression", "")  # "" or "GZIP"
assert AUDIO_SAMPLES == 64000 and TPRIME == 1000

# Loudness gate for voicing fallback
LOUD_GATE_DB = -60.0

import ddsp

FEATURE_SPEC = {
    "audio/inputs":      tf.io.FixedLenFeature([], tf.string),
    "audio/targets":     tf.io.FixedLenFeature([], tf.string),
    "audio/sample_rate": tf.io.FixedLenFeature([], tf.int64),
    "audio/length":      tf.io.FixedLenFeature([], tf.int64),
    "meta/track":        tf.io.FixedLenFeature([], tf.string),
    "meta/source":       tf.io.FixedLenFeature([], tf.string),
    "meta/start_sec":    tf.io.FixedLenFeature([], tf.float32),
    "meta/end_sec":      tf.io.FixedLenFeature([], tf.float32),
}
# ---- Log-mel from the INPUT (Demucs) ----
MEL_BINS = 64


def _interp_to_len(x, n):
    """Linear resample 1D numpy array x to length n."""
    if len(x) == n:
        return x.astype(np.float32, copy=False)
    xp = np.linspace(0.0, 1.0, num=len(x), dtype=np.float32)
    xq = np.linspace(0.0, 1.0, num=n,       dtype=np.float32)
    return np.interp(xq, xp, x).astype(np.float32)

def _torchcrepe_f0_np(y_in,
                      sr=SR,
                      frame_rate=FRAME_RATE,
                      fmin=80.0, fmax=1000.0,
                      periodicity_thresh=0.45,
                      use_gpu=True,
                      model_size="full"):
    """NumPy in -> NumPy out. Returns f0[TPRIME] in Hz."""
    # tf.py_function hands us an EagerTensor; convert safely to contiguous float32 np array
    y_np = np.asarray(y_in).astype(np.float32, copy=True)
    if y_np.ndim != 1:
        y_np = y_np.reshape(-1)
    if y_np.size == 0:
        return np.zeros((TPRIME,), dtype=np.float32)

    import torch, torchcrepe

    hop = int(round(sr / float(frame_rate)))
    device = 'cuda' if (use_gpu and torch.cuda.is_available()) else 'cpu'

    # torchcrepe expects shape [B, T]
    x = torch.from_numpy(y_np).unsqueeze(0).to(device)

    with torch.no_grad():
        f0, per = torchcrepe.predict(
            audio=x,
            sample_rate=sr,
            hop_length=hop,
            fmin=fmin,
            fmax=fmax,
            model=model_size,         # 'tiny' for speed, 'full' for quality
            batch_size=2048,
            device=device,
            return_periodicity=True
        )
        # Smooth for stability
        per = torchcrepe.filter.median(per, 3)
        per = torchcrepe.filter.mean(per, 3)
        f0  = torchcrepe.filter.median(f0, 3)
        f0  = torchcrepe.filter.mean(f0, 3)

    f0  = f0.squeeze(0).detach().cpu().numpy().astype(np.float32, copy=False)
    per = per.squeeze(0).detach().cpu().numpy().astype(np.float32, copy=False)

    # Periodicity mask: unvoiced -> 0 Hz
    f0[per < periodicity_thresh] = 0.0

    # Ensure exactly TPRIME frames
    f0 = _interp_to_len(f0, TPRIME)
    return f0

def _f0_from_torchcrepe_tf(y_1d,
                           fmin=90.0, fmax=600.0,   # tighter vocal range helps
                           periodicity_thresh=0.45,
                           model_size="full"):
    """TensorFlow wrapper: y_1d [S] -> f0 [TPRIME] float32 via torchcrepe."""
    f0 = tf.py_function(
        func=lambda x: _torchcrepe_f0_np(
            x, sr=SR, frame_rate=FRAME_RATE,
            fmin=fmin, fmax=fmax,
            periodicity_thresh=periodicity_thresh,
            model_size=model_size
        ),
        inp=[y_1d],
        Tout=tf.float32,
    )
    # Fix the static shape for downstream layers
    f0 = tf.ensure_shape(f0, [TPRIME])
    return f0


In [None]:
#@title Parameters and Training Data
def _logmel_1xTprime(y_1d, n_fft=1024):
    hop = tf.cast(tf.math.round(SR / FRAME_RATE), tf.int32)
    S = tf.signal.stft(
        y_1d, frame_length=n_fft, frame_step=hop, fft_length=n_fft,
        window_fn=tf.signal.hann_window, pad_end=True
    )                               # [T, F]
    mag = tf.abs(S)
    mel_fb = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=MEL_BINS,
        num_spectrogram_bins=n_fft // 2 + 1,
        sample_rate=SR,
        lower_edge_hertz=50.0,
        upper_edge_hertz=SR * 0.45,
    )                               # [F, M]
    mel = tf.matmul(mag, mel_fb)    # [T, M]
    mel = tf.math.log(mel + 1e-5)
    mel = tf.squeeze(ddsp.core.resample(mel[tf.newaxis, :, :], TPRIME), 0)  # [T', M]
    # per-sample normalize (helps training)
    m_mean = tf.reduce_mean(mel, axis=-1, keepdims=True)
    m_std  = tf.math.reduce_std(mel, axis=-1, keepdims=True) + 1e-5
    mel_n  = (mel - m_mean) / m_std
    return tf.ensure_shape(mel_n, [TPRIME, MEL_BINS])


def _pad_crop_1d(x, L):
    x = x[:L]
    pad = tf.maximum(0, L - tf.shape(x)[0])
    x = tf.pad(x, [[0, pad]])
    return tf.ensure_shape(x, [L])

def _resample_to_sr(y, sr_in, sr_out=SR):
    y2 = y[tf.newaxis, :]
    rate = tf.cast(sr_out, tf.float32) / tf.cast(sr_in, tf.float32)
    n_out = tf.cast(tf.round(tf.cast(tf.shape(y2)[1], tf.float32) * rate), tf.int32)
    y2 = ddsp.core.resample(y2, n_out)
    return tf.squeeze(y2, 0)

def _rms_loudness_db(y_1d):
    hop = tf.cast(tf.math.round(SR / FRAME_RATE), tf.int32)
    win = tf.maximum(hop * 2, 256)
    frames = tf.signal.frame(y_1d, frame_length=win, frame_step=hop, pad_end=True)
    rms = tf.sqrt(tf.reduce_mean(tf.square(frames), axis=-1) + 1e-12)
    ld_db = 20.0 * tf.math.log(rms + 1e-7) / tf.math.log(10.0)
    ld_db = tf.clip_by_value(ld_db, -120.0, 0.0)
    ld_db = ddsp.core.resample(ld_db[tf.newaxis, :], TPRIME)
    return tf.squeeze(ld_db, 0)

def _parse_and_features(serialized):
    ex = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    L  = tf.cast(ex["audio/length"], tf.int32)
    sr = tf.cast(ex["audio/sample_rate"], tf.int32)

    # float32 decode (your TFRecords are float32)
    x_in  = tf.io.decode_raw(ex["audio/inputs"],  tf.float32)[:L]
    y_tgt = tf.io.decode_raw(ex["audio/targets"], tf.float32)[:L]

    # resample to SR (16k for torchcrepe), crop/pad to WIN_S
    x_in  = _pad_crop_1d(_resample_to_sr(x_in,  sr, SR), AUDIO_SAMPLES)
    y_tgt = _pad_crop_1d(_resample_to_sr(y_tgt, sr, SR), AUDIO_SAMPLES)

    # conditioning from INPUT (Demucs)
    ld_db = _rms_loudness_db(x_in)
    f0_hz = _f0_from_torchcrepe_tf(x_in, fmin=90.0, fmax=600.0, periodicity_thresh=0.40, model_size="tiny")

    # safety: mute f0 on very quiet frames
    f0_hz = tf.where(ld_db > -60.0, f0_hz, tf.zeros_like(f0_hz))


    mel = _logmel_1xTprime(x_in)  # [T', 64]

    # features from INPUT (Demucs)
    mel_in = _logmel_1xTprime(x_in)                                  # [T', 64]
    f0_in  = _f0_from_torchcrepe_tf(x_in, fmin=90.0, fmax=600.0)

    # features from TARGET (clean)  — cached once, reused every epoch
    mel_gt = _logmel_1xTprime(y_tgt)
    f0_gt  = _f0_from_torchcrepe_tf(y_tgt, fmin=90.0, fmax=600.0)

    cond = {
        "x_in":        x_in,
        "f0_in":       f0_in,         # torchcrepe once
        # "f0_gt":     (omit for speed)
        "loudness_db": ld_db,
        "mel_in":      mel_in,        # cheap
        "mel_gt":      mel_gt,        # cheap
        "track":       ex["meta/track"],
    }
    return cond, y_tgt

def make_ds(files, compression="", cache_path=None):
    ds = tf.data.TFRecordDataset(files, compression_type=compression)
    ds = ds.map(_parse_and_features, num_parallel_calls=1, deterministic=True)
    ds = ds.cache(str(cache_path) if cache_path else None)
    ds = ds.batch(1, drop_remainder=True).prefetch(1)
    return ds

# Build sets
# train_ds = make_ds(train_files,  compression=compression)
# val_ds   = make_ds(val_files,    compression=compression)


In [None]:
#@title Model
# --- hard reset setup ---
import os, tensorflow as tf, tensorflow.keras as keras, ddsp
from tensorflow.keras import mixed_precision

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"   # no XLA while stabilizing
mixed_precision.set_global_policy("float32")         # keep dtypes simple

for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception:
        pass
tf.config.optimizer.set_jit(False)

# --- tiny harmonic-only model (same as before) ---
from ddsp.synths import Harmonic, FilteredNoise
from ddsp.effects import Reverb

N_HARMONICS   = 64
N_NOISE_BANDS = 65
MEL_BINS      = 64
LOUD_GATE_DB  = -60.0
MAX_HZ = 6000.0

class DDSPDecoder(keras.Model):
    def __init__(self, rnn_units=256, mlp_units=(256,128),
                 tf_decay_steps=5000,
                 tau=0.5,                 # ↓ lower temperature
                 sine_warmup_steps=300,   # shorter warmup
                 **kw):
        super().__init__(**kw)
        self.tau = float(tau)
        self.sine_warmup_steps = int(sine_warmup_steps)
        self.tf_decay_steps = int(tf_decay_steps)
        self.global_step = tf.Variable(0, trainable=False, dtype=tf.int64)




        # Timbre encoder from mel
        self.mel_proj = keras.layers.Dense(128, activation='relu')

        self.pre  = keras.layers.Dense(128, activation='relu')
        self.gru  = keras.layers.GRU(rnn_units, return_sequences=True)
        self.post = keras.Sequential([keras.layers.Dense(u, activation='relu') for u in mlp_units])

        # Bias harmonic logits with a steeper 1/h prior
        decay_bias = -0.30 * np.arange(N_HARMONICS, dtype=np.float32)  # <- steeper
        self.amp_head   = keras.layers.Dense(1)
        self.harm_head  = keras.layers.Dense(N_HARMONICS,
                              bias_initializer=keras.initializers.Constant(decay_bias))
        self.noise_head = keras.layers.Dense(N_NOISE_BANDS)

        # NEW: learnable harmonic roll-off factor α(t) ≥ 0
        self.roll_head  = keras.layers.Dense(1)  # predicts alpha(t)
        self.k_idx = tf.cast(tf.range(1, N_HARMONICS+1)[tf.newaxis, tf.newaxis, :], tf.float32)

        self.harm  = ddsp.synths.Harmonic(sample_rate=SR, amp_resample_method='linear')
        self.noise = ddsp.synths.FilteredNoise(n_samples=AUDIO_SAMPLES,
                                               scale_fn=ddsp.core.exp_sigmoid,
                                               initial_bias=-12.0)
        self.reverb = ddsp.effects.Reverb(trainable=True)

        # fixed 1/h tilt prior (broadcasted [1,1,H])
        self.tilt_log_b = tf.math.log(1.0 / tf.cast(tf.range(1, N_HARMONICS+1), tf.float32))[tf.newaxis, tf.newaxis, :]

        # Start dry near zero (don’t just copy Demucs)
        self.dry_logit = tf.Variable(-6.0, trainable=True)  # sigmoid≈0.0025

    def call(self, inputs, training=False):

        training_b = tf.cast(tf.convert_to_tensor(training) if not isinstance(training, bool)
                     else tf.constant(training), tf.bool)

        use_gt = tf.logical_and(training_b,
                        tf.less(self.global_step, tf.cast(self.tf_decay_steps, tf.int64)))

        # f0: ALWAYS from input (no torchcrepe on target)
        if "f0_in" in inputs:
            f0_hz = tf.cast(inputs["f0_in"], tf.float32)
        else:
            f0_hz = tf.cast(inputs["f0_hz"], tf.float32)  # fallback

        # mel: teacher forced (GT early, then Demucs)
        if ("mel_gt" in inputs) and ("mel_in" in inputs):
            mel = tf.cond(use_gt,
                          lambda: tf.cast(inputs["mel_gt"], tf.float32),
                          lambda: tf.cast(inputs["mel_in"], tf.float32))
        else:
            mel = tf.cast(inputs["mel"], tf.float32)      # fallback

        ld_db = tf.cast(inputs["loudness_db"], tf.float32)
        x_in  = tf.cast(inputs["x_in"], tf.float32)

        m = self.mel_proj(mel)                               # [B,T’,128]
        scalars = tf.stack([f0_hz, ld_db], axis=-1)          # [B,T’,2]
        x = tf.concat([scalars, m], axis=-1)                 # [B,T’,130]
        x = self.pre(x); x = self.gru(x); x = self.post(x)

        amp_raw     = ddsp.core.exp_sigmoid(self.amp_head(x))            # [B,T’,1]
        harm_logits = self.harm_head(x)                                   # [B,T’,H]
        alpha       = tf.nn.softplus(self.roll_head(x)) + 0.05            # α(t) ≥ 0.05
        roll        = tf.exp(-alpha * self.k_idx)                         # [B,T’,H], decays with k
        harm_pre    = harm_logits / self.tau + self.tilt_log_b + tf.math.log(roll + 1e-8)
        harm_dist   = tf.nn.softmax(harm_pre, axis=-1)                    # [B,T’,H]


        # cap by Nyquist for each frame so high harmonics vanish when f0 is high
        Hmax = tf.cast(tf.floor((SR * 0.5) / tf.maximum(f0_hz, 1e-6)), tf.int32)  # [B,T']
        mask = tf.sequence_mask(Hmax, maxlen=N_HARMONICS, dtype=tf.float32)       # [B,T',H]
        harm_dist = harm_dist * mask
        harm_dist = harm_dist / (tf.reduce_sum(harm_dist, axis=-1, keepdims=True) + 1e-8)



        noise_mag   = ddsp.core.exp_sigmoid(self.noise_head(x))           # [B,T’,B]

        ld_amp  = ddsp.core.db_to_amplitude(ld_db)                        # [B,T’]
        voicing = tf.cast(f0_hz > 0.0, tf.float32)
        voicing = tf.maximum(voicing, tf.cast(ld_db > LOUD_GATE_DB, tf.float32))

        amp = (amp_raw + 1e-3) * (ld_amp[..., None] + 1e-4) * voicing[..., None]
        noise_mag = (noise_mag + 1e-3) * tf.maximum(ld_amp, 1e-4)[..., None]
        # even less noise on voiced frames
        noise_mag *= (0.95*(1.0 - voicing) + 0.15*voicing)[..., None]

        # short warm-up: fundamental only

        warmup_active = tf.less(self.global_step, tf.cast(self.sine_warmup_steps, tf.int64))
        do_warmup = tf.logical_and(training_b, warmup_active)
        def _fundamental_only():
            b = tf.shape(harm_dist)[0]; t = tf.shape(harm_dist)[1]
            return tf.one_hot(tf.zeros([b, t], tf.int32), N_HARMONICS)
        harm_dist = tf.cond(do_warmup, _fundamental_only, lambda: harm_dist)

        amp       = tf.ensure_shape(amp,       [None, TPRIME, 1])
        harm_dist = tf.ensure_shape(harm_dist, [None, TPRIME, N_HARMONICS])
        noise_mag = tf.ensure_shape(noise_mag, [None, TPRIME, N_NOISE_BANDS])

        f0_3d   = f0_hz[..., tf.newaxis]
        audio_h = self.harm(amplitudes=amp, harmonic_distribution=harm_dist, f0_hz=f0_3d)
        audio_n = self.noise(magnitudes=noise_mag)

        synth = self.reverb(audio_h + audio_n)
        dry_g = tf.nn.sigmoid(self.dry_logit)
        return dry_g * x_in + (1.0 - dry_g) * synth

from ddsp.losses import SpectralLoss
spec_loss = SpectralLoss(
    fft_sizes=(2048,1024,512,256,128,64),
    loss_type='L1',
    mag_weight=1.0,
    logmag_weight=1.0,
    delta_freq_weight=0.5,
    delta_time_weight=0.1,
)

def mel_spec(y, n_fft=1024, hop=None):
    if hop is None:
        hop = int(round(SR / FRAME_RATE))
    S = tf.abs(tf.signal.stft(y, n_fft, hop, n_fft, window_fn=tf.signal.hann_window, pad_end=True))
    mel_fb = tf.signal.linear_to_mel_weight_matrix(MEL_BINS, n_fft//2+1, SR, 50.0, SR*0.45)
    M = tf.matmul(S, mel_fb)  # [B, Tm, M]
    return tf.math.log(M + 1e-5)

def mel_l1(y_true, y_pred):
    Yt = mel_spec(y_true); Yp = mel_spec(y_pred)
    n = tf.minimum(tf.shape(Yt)[1], tf.shape(Yp)[1])
    return tf.reduce_mean(tf.abs(Yt[:, :n, :] - Yp[:, :n, :]))

def spec_centroid(y, n_fft=1024, hop=None):
    if hop is None:
        hop = int(round(SR / FRAME_RATE))
    S = tf.abs(tf.signal.stft(y, n_fft, hop, n_fft, window_fn=tf.signal.hann_window, pad_end=True))  # [B,T,F]
    freqs = tf.linspace(0.0, tf.cast(SR, tf.float32)/2.0, n_fft//2+1)                                # [F]
    num = tf.reduce_sum(S * freqs[tf.newaxis, tf.newaxis, :], axis=-1)                               # [B,T]
    den = tf.reduce_sum(S + 1e-8, axis=-1)                                                           # [B,T]
    c   = num / (den + 1e-8)                                                                          # [B,T], Hz
    return tf.reduce_mean(c, axis=-1)                                                                 # [B]

def centroid_l1(y_true, y_pred):
    ct = spec_centroid(y_true); cp = spec_centroid(y_pred)
    return tf.reduce_mean(tf.abs(ct - cp)) / (SR/2.0)  # normalize to [0,1]

class DDSPTrainer(keras.Model):
    def __init__(self, net, mel_w=1.0, cent_w=0.05):
        super().__init__(); self.net = net
        self.mel_w, self.cent_w = mel_w, cent_w
        self.train_metric = keras.metrics.Mean(name="loss")
        self.val_metric   = keras.metrics.Mean(name="val_loss")

    @property
    def metrics(self): return [self.train_metric, self.val_metric]
    def build(self, _=None): self.built = True
    def compile(self, optimizer): super().compile(); self.optimizer = optimizer

    @tf.function
    def _loss(self, y_t, y_p):
        Ls = spec_loss(y_t, y_p)
        Lm = mel_l1(y_t, y_p)
        Lc = centroid_l1(y_t, y_p)
        return Ls + self.mel_w*Lm + self.cent_w*Lc

    @tf.function
    def train_step(self, data):
        cond, target = data
        with tf.GradientTape() as tape:
            pred = self.net(cond, training=True)
            n = tf.minimum(tf.shape(pred)[1], tf.shape(target)[1])
            loss = self._loss(tf.cast(target[:, :n], tf.float32), tf.cast(pred[:, :n], tf.float32))
        grads = tape.gradient(loss, self.net.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.net.trainable_variables))
        if hasattr(self.net, "global_step"): self.net.global_step.assign_add(1)
        self.train_metric.update_state(loss)
        return {"loss": self.train_metric.result()}

    @tf.function
    def test_step(self, data):
        cond, target = data
        pred = self.net(cond, training=False)
        n = tf.minimum(tf.shape(pred)[1], tf.shape(target)[1])
        loss = self._loss(tf.cast(target[:, :n], tf.float32), tf.cast(pred[:, :n], tf.float32))
        self.val_metric.update_state(loss)
        return {"loss": self.val_metric.result()}

In [None]:
# ---- One-time setup (run once) ----
import numpy as np, torch, torchcrepe, tensorflow as tf
torch.set_num_threads(1)  # avoid CPU thread thrash

def _torchcrepe_f0_np(y_np, sr, hop_length, fmin, fmax, periodicity_thresh, model_size):
    # y_np: 1D float32
    y_np = np.ascontiguousarray(y_np.astype(np.float32), dtype=np.float32)
    x = torch.from_numpy(y_np[None, None, :])  # [B=1, C=1, T], CPU
    with torch.no_grad():
        f0, pd = torchcrepe.predict(
            x, int(sr), int(hop_length),
            float(fmin), float(fmax),
            model=str(model_size),           # "tiny" or "full"
            batch_size=1024,
            device="cpu",                    # keep crepe on CPU
            return_periodicity=True,
        )
    f0 = f0.squeeze(0).squeeze(0).cpu().numpy()
    pd = pd.squeeze(0).squeeze(0).cpu().numpy()
    f0[pd < float(periodicity_thresh)] = 0.0  # simple V/UV mask
    return f0.astype(np.float32, copy=False)

def _f0_from_torchcrepe_tf(
    y_1d,
    fmin=90.0,
    fmax=600.0,
    hop_length=512,
    periodicity_thresh=0.40,
    model_size="tiny",
    # SR and TPRIME should be defined globally in your notebook
):
    f0 = tf.numpy_function(
        _torchcrepe_f0_np,
        [y_1d,
         tf.constant(SR, tf.int32),
         tf.constant(hop_length, tf.int32),
         tf.constant(fmin, tf.float32),
         tf.constant(fmax, tf.float32),
         tf.constant(periodicity_thresh, tf.float32),
         tf.constant(model_size, tf.string)],
        Tout=tf.float32,
    )
    f0 = tf.ensure_shape(tf.reshape(f0, [-1]), [None])            # [Tf0]
    f0 = ddsp.core.resample(f0[tf.newaxis, :], TPRIME)            # [1, T']
    return tf.ensure_shape(tf.squeeze(f0, 0), [TPRIME])           # [T']


In [None]:
from pathlib import Path

def make_cached_slice(files, K=400, compression="", cache_path=None):
    ds = tf.data.TFRecordDataset(files, compression_type=compression)
    ds = ds.map(_parse_and_features, num_parallel_calls=1, deterministic=True)
    ds = ds.take(K)
    ds = ds.cache(str(cache_path) if cache_path else None)   # RAM/disk
    ds = ds.shuffle(K, reshuffle_each_iteration=True).repeat()
    ds = ds.batch(1, drop_remainder=True).prefetch(1)
    return ds

# Local (fast) cache paths — avoid Google Drive paths here
train_run = make_cached_slice(train_files, K=400, compression=compression, cache_path="/content/cache/train_slice")
val_run   = make_cached_slice(val_files,   K=120, compression=compression, cache_path="/content/cache/val_slice")


In [None]:
# Run once
import numpy as np, torch, torchcrepe

torch.set_num_threads(1)  # avoid CPU thread thrash

def _torchcrepe_f0_np(y_np, sr, hop_length, fmin, fmax, periodicity_thresh, model_size):
    # Flatten to 1D, contiguous, float32
    y_np = np.asarray(y_np, dtype=np.float32).reshape(-1)
    y_np = np.ascontiguousarray(y_np)
    # torchcrepe wants [B, 1, T] on CPU
    x = torch.from_numpy(y_np[None, None, :])  # [1,1,T]
    with torch.no_grad():
        f0, pd = torchcrepe.predict(
            x, int(sr), int(hop_length),
            float(fmin), float(fmax),
            model=str(model_size),          # "tiny" or "full"
            batch_size=1024,
            device="cpu",
            return_periodicity=True,
        )
    f0 = f0.squeeze().cpu().numpy()
    pd = pd.squeeze().cpu().numpy()
    f0[pd < float(periodicity_thresh)] = 0.0
    return f0.astype(np.float32, copy=False)


In [None]:
import tensorflow as tf, ddsp, numpy as np

# === utils ===
def _to_Tprime(vec_1d, TPRIME):
    """vec_1d: [T*] → [TPRIME] via linear resample."""
    vec_1d = tf.reshape(tf.cast(vec_1d, tf.float32), [-1])
    out = ddsp.core.resample(vec_1d[tf.newaxis, :], int(TPRIME))  # [1, TPRIME]
    return tf.squeeze(out, 0)

def _pad_crop_1d(x, length):
    x = tf.reshape(x, [-1])
    x = x[:length]
    pad = tf.maximum(0, length - tf.shape(x)[0])
    x = tf.pad(x, [[0, pad]])
    x.set_shape([length])
    return x

def _resample_to_sr(x, sr_in, sr_out):
    x = tf.reshape(x, [-1])
    if sr_in == sr_out:
        return x
    ratio = tf.cast(sr_out, tf.float32) / tf.cast(sr_in, tf.float32)
    new_len = tf.cast(tf.round(tf.cast(tf.shape(x)[0], tf.float32) * ratio), tf.int32)
    return tf.reshape(ddsp.core.resample(x[tf.newaxis, :], new_len)[0], [-1])

def _rms_loudness_db(x, TPRIME, eps=1e-8):
    # rough per-frame loudness, then resample to TPRIME
    x = tf.reshape(x, [1, -1])
    # pick a hop that gives ~TPRIME frames
    hop = tf.maximum(1, tf.shape(x)[1] // int(TPRIME))
    win = tf.maximum(256, hop * 2)
    stft = tf.signal.stft(x, frame_length=win, frame_step=hop,
                          window_fn=tf.signal.hann_window, pad_end=True)
    mag = tf.abs(stft)  # [1, T*, F]
    rms = tf.sqrt(tf.reduce_mean(tf.square(mag), axis=-1) + eps)  # [1, T*]
    ld  = 20.0 * tf.math.log(rms + eps) / tf.math.log(10.0)
    ld  = tf.squeeze(ld, 0)              # [T*]
    return _to_Tprime(ld, TPRIME)        # [TPRIME]

def _logmel_1xTprime(x, SR, TPRIME, n_fft=1024, mel_bins=64):
    x = tf.reshape(x, [1, -1])
    hop = tf.maximum(1, tf.shape(x)[1] // int(TPRIME))
    S = tf.abs(tf.signal.stft(x, frame_length=n_fft, frame_step=hop,
                              window_fn=tf.signal.hann_window, pad_end=True))  # [1,T*,F]
    mel_fb = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=mel_bins,
        num_spectrogram_bins=n_fft//2 + 1,
        sample_rate=int(SR),
        lower_edge_hertz=50.0,
        upper_edge_hertz=float(SR) * 0.45,
    )
    M = tf.matmul(S, mel_fb)            # [1,T*,mel]
    M = tf.math.log(M + 1e-5)
    M = tf.squeeze(M, 0)                # [T*, mel]
    # resample along time to TPRIME
    M = tf.transpose(M, [1, 0])         # [mel, T*]
    M = ddsp.core.resample(M[tf.newaxis, ...], int(TPRIME))  # [1, mel, TPRIME]
    M = tf.squeeze(M, 0)                # [mel, TPRIME]
    return tf.transpose(M, [1, 0])      # [TPRIME, mel]


In [None]:
# Set these to your globals
# SR = 22050
# AUDIO_SAMPLES = 64000
# TPRIME = 1000   # frames per 4 s at 250 Hz, etc.

def _parse_no_f0(serialized):
    feats = {
        "audio/inputs":      tf.io.FixedLenFeature([], tf.string),
        "audio/targets":     tf.io.FixedLenFeature([], tf.string),
        "audio/sample_rate": tf.io.FixedLenFeature([], tf.int64),
        "meta/track":        tf.io.FixedLenFeature([], tf.string),
    }
    ex  = tf.io.parse_single_example(serialized, feats)
    xin = tf.io.decode_raw(ex["audio/inputs"], tf.float32)
    ygt = tf.io.decode_raw(ex["audio/targets"], tf.float32)
    sr  = tf.cast(ex["audio/sample_rate"], tf.int32)

    xin = _pad_crop_1d(_resample_to_sr(xin, sr, SR), AUDIO_SAMPLES)
    ygt = _pad_crop_1d(_resample_to_sr(ygt, sr, SR), AUDIO_SAMPLES)

    ld      = _rms_loudness_db(xin, TPRIME)                       # [TPRIME]
    mel_in  = _logmel_1xTprime(xin, SR, TPRIME, n_fft=1024)       # [TPRIME, 64]
    mel_gt  = _logmel_1xTprime(ygt, SR, TPRIME, n_fft=1024)

    cond = {
        "x_in":        xin,
        "f0_in":       tf.zeros([TPRIME], tf.float32),  # stub: no crepe here
        "loudness_db": ld,
        "mel_in":      mel_in,
        "mel_gt":      mel_gt,
        "track":       ex["meta/track"],
    }
    return cond, ygt


In [None]:
import time, tensorflow as tf
t0 = time.time()
_ = next(iter(tf.data.TFRecordDataset(train_files).map(_parse_no_f0).take(1)))
print("parse(no f0):", time.time() - t0, "s")



parse(no f0): 1.0155503749847412 s


In [None]:
import time

# A) TFRecord I/O only
t0=time.time(); _ = next(iter(tf.data.TFRecordDataset(train_files).take(1))); print("I/O:", time.time()-t0, "s")

I/O: 0.027957677841186523 s


In [None]:
# B) Parse WITHOUT f0 (temporarily stub f0 to zeros to test)
def _parse_no_f0(rec):
    cond, y = _parse_and_features(rec)  # if your function fuses both, duplicate and remove the f0 call
    cond["f0_in"] = tf.zeros([TPRIME], tf.float32)  # stub
    return cond, y
t0=time.time(); _ = next(iter(tf.data.TFRecordDataset(train_files).map(_parse_no_f0).take(1))); print("parse(no f0):", time.time()-t0, "s")

UnknownError: {{function_node __wrapped__IteratorGetNext_output_types_7_device_/job:localhost/replica:0/task:0/device:CPU:0}} Error in user-defined function passed to MapDataset:26 transformation with iterator: Iterator::Root::Prefetch::FiniteTake::Map: RuntimeError: Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: [1, 1, 1, 1, 65024]
Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/ops/script_ops.py", line 269, in __call__
    ret = func(*args)
          ^^^^^^^^^^^

  File "/usr/local/lib/python3.12/dist-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipython-input-1094770658.py", line 10, in _torchcrepe_f0_np
    f0, pd = torchcrepe.predict(
             ^^^^^^^^^^^^^^^^^^^

  File "/usr/local/lib/python3.12/dist-packages/torchcrepe/core.py", line 117, in predict
    for frames in generator:
                  ^^^^^^^^^

  File "/usr/local/lib/python3.12/dist-packages/torchcrepe/core.py", line 682, in preprocess
    frames = torch.nn.functional.unfold(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py", line 5611, in unfold
    return torch._C._nn.im2col(
           ^^^^^^^^^^^^^^^^^^^^

RuntimeError: Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: [1, 1, 1, 1, 65024]


	 [[{{node PyFunc}}]] [Op:IteratorGetNext] name: 

In [None]:
#@title Building the trainer
model   = DDSPDecoder()
trainer = DDSPTrainer(model, mel_w=1.0, cent_w=0.05)
trainer.compile(optimizer=keras.optimizers.Adam(1e-3))

dummy = {
  "f0_hz":       tf.zeros([1, TPRIME], tf.float32),
  "loudness_db": tf.ones([1, TPRIME], tf.float32) * -40.0,
  "mel":         tf.zeros([1, TPRIME, MEL_BINS], tf.float32),
  "x_in":        tf.zeros([1, AUDIO_SAMPLES], tf.float32),
}
_ = model(dummy, training=False)
trainer.build(None)


# Callbacks
from pathlib import Path
CKPT_DIR = Path("/content/ckpt"); CKPT_DIR.mkdir(parents=True, exist_ok=True)

callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=str(CKPT_DIR / "ddsp.weights.h5"),
        save_weights_only=True,
        monitor="val_loss",    # now available
        mode="min",
        save_best_only=True,
    ),
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=5,
        restore_best_weights=True,
    ),
]

from tensorflow.data import experimental as tfd

def make_cached_slice(files, K=1000, compression="", cache_path=None):
    base = tf.data.TFRecordDataset(files, compression_type=compression)
    base = base.map(_parse_and_features, num_parallel_calls=1, deterministic=True)
    if cache_path is None:
        base = base.take(K).cache()                         # RAM
    else:
        base = base.take(K).cache(str(cache_path))          # disk
    # Important: shuffle *after* caching and use K as buffer to permute the slice
    base = base.shuffle(K, reshuffle_each_iteration=True)
    base = base.repeat()                                    # endless
    base = base.batch(1, drop_remainder=True).prefetch(1)
    return base

# Use the slice for speed while you iterate on the model:
TRAIN_STEPS = 400
VAL_STEPS   = 80

train_ds = make_ds(train_files, compression=compression, cache_path="/content/cache/train")
val_ds   = make_ds(val_files,   compression=compression, cache_path="/content/cache/val")
train_run = make_cached_slice(train_files, K=400, compression=compression, cache_path=None)
val_run   = make_cached_slice(val_files,   K=120, compression=compression, cache_path=None)

# train_run = make_cached_slice(train_files, K=400, compression=compression)
# val_run   = make_cached_slice(val_files,   K=120,  compression=compression)

In [None]:
%time _ = next(iter(train_run.take(1)))

In [None]:
#@title Fitting

history = trainer.fit(
    train_run,
    validation_data=val_run,
    epochs=3,
    steps_per_epoch=TRAIN_STEPS,
    validation_steps=VAL_STEPS,
    callbacks=callbacks,
    verbose=1,
)

Epoch 1/3


In [None]:
#@title Objective Evaluation
import numpy as np
from IPython.display import Audio, display

def si_sdr(ref, est):
    ref = ref - np.mean(ref); est = est - np.mean(est)
    a = np.dot(est, ref) / (np.dot(ref, ref) + 1e-12)
    e_true = a * ref; e_res = est - e_true
    return 10*np.log10((np.sum(e_true**2)+1e-12)/(np.sum(e_res**2)+1e-12))

def spec_conv(a, b, n_fft=1024, hop=256):
    A = np.abs(tf.signal.stft(a, n_fft, hop).numpy())
    B = np.abs(tf.signal.stft(b, n_fft, hop).numpy())
    return np.linalg.norm(A-B) / (np.linalg.norm(A)+1e-12)

sdrs, scs = [], []
for k, (cond, tgt) in enumerate(val_ds.take(16)):
    pred = model({
        "f0_hz":       cond["f0_hz"],
        "loudness_db": cond["loudness_db"],
        "mel":         cond["mel"],
        "x_in":        cond["x_in"],
    }, training=False)

    y = tgt[0].numpy().astype(np.float32)
    p = pred[0].numpy().astype(np.float32)
    n = min(len(y), len(p)); y, p = y[:n], p[:n]
    sdrs.append(si_sdr(y, p))
    scs.append(spec_conv(y, p))
print(f"Val SI-SDR median: {np.median(sdrs):.2f} dB")
print(f"Val SpectralConv median: {np.median(scs):.3f}")


Val SI-SDR median: -29.94 dB
Val SpectralConv median: 1.308


In [None]:
from IPython.display import Audio, display
import numpy as np

def play_idx(ds, idx, model, sr=SR):
    # get the idx-th item
    for i, (cond, tgt) in enumerate(ds.skip(idx).take(1)):
        pred = model({
          "f0_hz": cond["f0_hz"],
          "loudness_db": cond["loudness_db"],
          "mel": cond["mel"],
          "x_in": cond["x_in"],
        }, training=False)
        y = tgt[0].numpy().astype(np.float32)
        p = pred[0].numpy().astype(np.float32)
        display(Audio(y, rate=sr))
        display(Audio(p, rate=sr))
        break

play_idx(val_ds, idx=12, model=model)     # specific

NameError: name 'SR' is not defined