In [2]:
# Remove RAPIDS & TF families that cause dependency conflicts in Colab
!pip -q uninstall -y \
  cudf-cu12 cuml-cu12 cugraph-cu12 dask-cudf-cu12 dask-cuda rapids-dask-dependency \
  distributed-ucxx-cu12 ucx-py-cu12 rmm-cu12 cudf-cu12 cuvs-cu12 cupy-cuda12x raft-dask-cu12 nx-cugraph-cu12 \
  pylibcugraph-cu12 pylibraft-cu12 ucxx-cu12 \
  tensorflow tensorflow-text tensorflow-decision-forests tf-keras keras keras-hub tensorflow-hub \
  ddsp crepe || true

# Fresh packaging toolchain
!pip -q install -U pip setuptools wheel build jedi

[0m

In [1]:
# Fresh NumPy/SciPy
!pip -q install --no-cache-dir --force-reinstall "numpy==2.1.3" "scipy==1.14.1"

# numba compatible with NumPy 2.1.x (needed by librosa/resampy/etc.)
!pip -q install -U "numba==0.62.0"

# Audio + utils
!pip -q install -U librosa soundfile absl-py

# JAX + ecosystem
!pip -q install -U "jax[cpu]>=0.4.28" "flax>=0.8.2" "optax>=0.2.2" chex orbax-checkpoint gin-config clu

# TensorFlow stack that works on Py3.12 and satisfies DDSP/TFP imports
!pip -q install -U "tensorflow==2.20.0" "tf-keras==2.20.0" "tensorflow-probability==0.25.0"

# Silence CPU runtimes warning if present
!pip -q uninstall -y jax-cuda12-plugin -q || true

# Install DDSP source without its pinned (outdated) requirements
!pip -q install --no-deps "git+https://github.com/magenta/ddsp@main#egg=ddsp"

# f0 extractor (replacement for CREPE)
!pip -q install -U torch torchcrepe

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
dopamine-rl 4.1.2 requires tensorflow>=2.2.0, which is not installed.
dopamine-rl 4.1.2 requires tf-keras>=2.18.0, which is not installed.[0m[31m
Reason for being yanked: <none given>[0m[33m
[0m  Preparing metadata (setup.py) ... [?25l[?25hdone
[33m  DEPRECATION: Building 'ddsp' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'ddsp'. Discussion can be found at https://github.com/pypa/pip/issues/6334[0m[33m
[0m  Building wheel for ddsp (setup.py) ... [?25l[?25hdone


In [3]:
import os
os.environ.setdefault("TF_GPU_ALLOCATOR", "cuda_malloc_async")  # helps fragmentation

import tensorflow as tf
# Let TF allocate GPU memory gradually
for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as e:
        print("Could not set memory growth:", e)

# Mixed precision halves activation memory on GPU
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")

print("Policy:", mixed_precision.global_policy())

Policy: <DTypePolicy "mixed_float16">


In [4]:
# Provide a tiny 'crepe' module so ddsp.losses can import; we won't call it.
import sys, types
_crepe = types.ModuleType("crepe")
def _predict_stub(*args, **kwargs):
    raise RuntimeError("ddsp.losses tried to call crepe.predict; use torchcrepe for f0 instead.")
_crepe.predict = _predict_stub
sys.modules["crepe"] = _crepe

# Smoke test
import numpy as np, librosa, soundfile, torch, torchcrepe
import tensorflow_probability as tfp
import jax, flax, ddsp

print("numpy:", np.__version__)
print("TF / TFP:", tf.__version__, "/", tfp.__version__)
print("jax / flax:", jax.__version__, "/", flax.__version__)
print("ddsp:", ddsp.__version__)

numpy: 2.1.3
TF / TFP: 2.20.0 / 0.25.0
jax / flax: 0.7.2 / 0.12.0
ddsp: 3.7.0


In [11]:
#@title Mount & paths
from google.colab import drive
from pathlib import Path
import yaml, json, math, time

drive.mount('/content/drive', force_remount=True)

PROJECT_DIR = Path('/content/drive/MyDrive/ddsp-demucs')
CFG = yaml.safe_load(open(PROJECT_DIR / 'env' / 'config.yaml'))

TFRECORDS_DIR = Path(CFG['paths']['tfrecords_dir'])
EXP_DIR = Path(CFG['paths']['exp_dir']) / 'run_ddsp_001'
EXP_DIR.mkdir(parents=True, exist_ok=True)

print("TFRecords:", TFRECORDS_DIR)
print("Exp:", EXP_DIR)

Mounted at /content/drive
TFRecords: /content/drive/MyDrive/ddsp-demucs/data/tfrecords
Exp: /content/drive/MyDrive/ddsp-demucs/exp/run_ddsp_001


In [6]:
#@title TFRecord parsing and feature extraction (torchcrepe f0)
import ddsp
from ddsp.spectral_ops import compute_loudness
# --- single source of truth ---
SR = 16000
WIN_S = 1.0

FRAME_RATE = 250                  # <- your features‚Äô fps
HOP_SIZE = int(round(SR / FRAME_RATE))  # ‚âà88 samples
TPRIME = int(WIN_S * FRAME_RATE)  # 1.0s * 250 = 250
AUDIO_SAMPLES = int(WIN_S * SR)   # 1.0s * 16000 = 16000

feature_description = {
    "audio/inputs":  tf.io.FixedLenFeature([], tf.string),
    "audio/targets": tf.io.FixedLenFeature([], tf.string),
    "audio/sample_rate": tf.io.FixedLenFeature([], tf.int64),
    "audio/length":      tf.io.FixedLenFeature([], tf.int64),
    "meta/track":   tf.io.FixedLenFeature([], tf.string),
    "meta/subset":  tf.io.FixedLenFeature([], tf.string),
}

def _parse_ex(serialized):
    ex = tf.io.parse_single_example(serialized, feature_description)
    sr = tf.cast(ex["audio/sample_rate"], tf.int32)
    xin = tf.io.decode_raw(ex["audio/inputs"], tf.float32)
    xgt = tf.io.decode_raw(ex["audio/targets"], tf.float32)
    xin.set_shape([None]); xgt.set_shape([None])
    return xin, xgt, sr, ex["meta/track"], ex["meta/subset"]

# torchcrepe f0 wrapper
import torch
import torchcrepe
import numpy as np

def torchcrepe_f0(audio_1d: np.ndarray, sr: int, hop_length: int = HOP_SIZE,
                  fmin: float = 50., fmax: float = 1100.) -> np.ndarray:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.tensor(audio_1d, dtype=torch.float32, device=device)[None]  # [1, T]
    with torch.no_grad():
        f0 = torchcrepe.predict(
            x, sr, hop_length,
            torch.tensor([fmin], device=device),
            torch.tensor([fmax], device=device),
            model="full", batch_size=1024, device=device, return_periodicity=False
        )[0].cpu().numpy()  # [frames]
    # Replace NaNs/Infs
    f0 = np.where(np.isfinite(f0), f0, 0.0)
    return f0

@tf.function
def _pad_or_trim_1d(x, T):
    x = x[:T]
    paddings = [[0, tf.maximum(0, T - tf.shape(x)[0])]]
    return tf.pad(x, paddings)

def make_example(xin, xgt, sr, track, subset):
    # Fix length to window duration
    T = int(round(WIN_S * SR))
    xin = _pad_or_trim_1d(xin, T)
    xgt = _pad_or_trim_1d(xgt, T)

    # Compute f0 with torchcrepe (in numpy, then back to tf)
    xin_np = xin.numpy() if isinstance(xin, tf.Tensor) else xin
    f0_np = torchcrepe_f0(np.array(xin_np, dtype=np.float32), SR, HOP_SIZE)
    f0_hz = tf.convert_to_tensor(f0_np, dtype=tf.float32)

    # Loudness (A-weighted) using DDSP; expects [batch, time]
    x_b = tf.expand_dims(xin, 0)
    ld = compute_loudness(x_b, sample_rate=SR, frame_rate=FRAME_RATE, use_tf=True, ref_db=20.7)
    ld = tf.squeeze(ld, 0)

    cond = {"f0_hz": f0_hz, "loudness_db": ld}
    return cond, xgt

def dataset_from_dir(split, batch_size=8, shuffle=True):
    # inside dataset_from_dir(...)
    READ_PAR = 1                      # was AUTOTUNE; 1 is safer on RAM
    SHUFFLE_BUF = 512                 # was 4096
    MAP_PAR = 2                       # keep low, bump later if stable

    files = sorted((TFRECORDS_DIR / split).glob("*.tfrecord"))
    ds = tf.data.TFRecordDataset([str(p) for p in files],
                                num_parallel_reads=READ_PAR,
                                buffer_size=64*1024)  # default is 262,144; lower a bit

    if shuffle:
        ds = ds.shuffle(SHUFFLE_BUF, reshuffle_each_iteration=True)

    ds = ds.map(_parse_ex, num_parallel_calls=MAP_PAR, deterministic=False)


    # wrapper that returns only plain ndarrays, not dicts / tensors
    def _make_numpy(xin, xgt, sr, track, subset):
        # convert TF tensors to numpy for py_function
        xin_np    = np.array(xin)
        xgt_np    = np.array(xgt)
        sr_np     = np.array(sr).item()   # scalar
        track_np  = np.array(track).item()  # if scalar/string id
        subset_np = np.array(subset).item()

        # your current make_example(...) likely returns (cond_dict, target)
        # change it (or adapt here) to return (f0_hz, loudness_db, target)
        f0_hz, loud_db, target = make_example(xin_np, xgt_np, sr_np, track_np, subset_np)
        return (f0_hz.astype('float32'),
                loud_db.astype('float32'),
                target.astype('float32'))

    def _map_py(xin, xgt, sr, track, subset):
        f0_hz, loud_db, target = tf.py_function(
            func=_make_numpy,
            inp=[xin, xgt, sr, track, subset],
            Tout=(tf.float32, tf.float32, tf.float32)
        )
        # shapes are unknown after py_function; set them
        f0_hz.set_shape([None])
        loud_db.set_shape([None])
        target.set_shape([None])

        cond = {"f0_hz": f0_hz, "loudness_db": loud_db}
        return cond, target

    ds = ds.map(_map_py, num_parallel_calls=tf.data.AUTOTUNE)

    # If all examples have equal frame length, regular batch is fine.
    # If lengths vary, use padded_batch instead of batch.
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds

BATCH = 8
train_ds = dataset_from_dir("train", batch_size=BATCH, shuffle=True)
val_ds   = dataset_from_dir("val",   batch_size=BATCH, shuffle=False)

In [8]:
def _pad_crop_time(x, length):
    # x: [B, T]  -> [B, length]
    t = tf.shape(x)[1]
    x = x[:, :length]
    pad = tf.maximum(0, length - tf.shape(x)[1])
    x = tf.pad(x, [[0, 0], [0, pad]])
    x = tf.ensure_shape(x, [None, length])
    return x

def _pad_crop_audio(y, length):
    # y: [B, N] -> [B, length]
    n = tf.shape(y)[1]
    y = y[:, :length]
    pad = tf.maximum(0, length - tf.shape(y)[1])
    y = tf.pad(y, [[0, 0], [0, pad]])
    y = tf.ensure_shape(y, [None, length])
    return y

def fix_batch_shapes(cond, target):
    cond = dict(cond)  # avoid side-effects
    cond['f0_hz']      = _pad_crop_time(cond['f0_hz'], TPRIME)
    cond['loudness_db'] = _pad_crop_time(cond['loudness_db'], TPRIME)
    target = _pad_crop_audio(target, AUDIO_SAMPLES)
    return cond, target


# Make sure train_ds and val_ds are already batched before this map.
train_ds = train_ds.map(fix_batch_shapes, num_parallel_calls=1).prefetch(1)
val_ds   = val_ds.map(fix_batch_shapes,   num_parallel_calls=1).prefetch(1)

In [10]:
def _pad_crop_time(x, length):
    # x: [B, T] -> [B, length]
    x = tf.convert_to_tensor(x)
    tf.debugging.assert_rank(x, 2, message="f0/loudness must be [B,T]")
    x = x[:, :length]
    pad = tf.maximum(0, length - tf.shape(x)[1])
    x = tf.pad(x, [[0, 0], [0, pad]])
    x = tf.ensure_shape(x, [None, length])
    return x

def _pad_crop_audio(y, length):
    # Accept [B, N] or [B, N, C]. If 3D, collapse channels to mono by mean.
    y = tf.convert_to_tensor(y)
    rank = tf.rank(y)

    def _from_2d():
        return y  # [B, N]

    def _from_3d():
        # [B, N, C] -> [B, N] (mono)
        return tf.reduce_mean(y, axis=-1)

    y = tf.cond(tf.equal(rank, 2), _from_2d, _from_3d)
    # Now y is [B, N]
    y = y[:, :length]
    pad = tf.maximum(0, length - tf.shape(y)[1])
    y = tf.pad(y, [[0, 0], [0, pad]])
    y = tf.ensure_shape(y, [None, length])
    return y

def fix_batch_shapes(cond, target):
    # Ensure dtypes and shapes after BATCHING.
    cond = {
        "f0_hz":      _pad_crop_time(tf.cast(cond["f0_hz"], tf.float32), TPRIME),
        "loudness_db": _pad_crop_time(tf.cast(cond["loudness_db"], tf.float32), TPRIME),
    }
    target = _pad_crop_audio(tf.cast(target, tf.float32), AUDIO_SAMPLES)
    return cond, target

# If you control the builder, also shrink shuffle/reads there.
# Regardless, clamp the ready-made datasets here:
clamp_shuffle = 256     # keep small
def clamp(ds, do_shuffle=False):
    if do_shuffle:
        ds = ds.shuffle(clamp_shuffle, reshuffle_each_iteration=True)
    ds = ds.map(fix_batch_shapes, num_parallel_calls=1, deterministic=False)
    ds = ds.prefetch(1)
    return ds

train_ds = clamp(train_ds, do_shuffle=True)
val_ds   = clamp(val_ds,   do_shuffle=False)

In [19]:
#@title Build harmonic+noise+reverb model (Keras)
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import mixed_precision


from ddsp.synths import Harmonic, FilteredNoise
from ddsp.effects import Reverb
from ddsp.losses import SpectralLoss

N_HARMONICS = 32
N_NOISE_BANDS = 33

class DDSPDecoder(keras.Model):
    def __init__(self, rnn_units=256, mlp_units=(256, 128), **kwargs):
        super().__init__(**kwargs)
        self.f0midi_range = (24.0, 84.0)  # clip to [C1, C6]

        self.pre  = keras.layers.Dense(128, activation='relu')
        self.gru  = keras.layers.GRU(rnn_units, return_sequences=True)
        self.post = keras.Sequential([keras.layers.Dense(u, activation='relu') for u in mlp_units])

        self.amp_head   = keras.layers.Dense(1)              # harmonic amplitude
        self.harm_head  = keras.layers.Dense(N_HARMONICS)    # harmonic distribution
        self.noise_head = keras.layers.Dense(N_NOISE_BANDS)  # noise magnitudes

        # Effects and synths
        self.reverb = Reverb(trainable=True)
        self.harm = Harmonic(sample_rate=SR, amp_resample_method='linear')
        self.noise = FilteredNoise(
            n_samples=int(round(WIN_S * SR)),                 # fixed-length windows
            scale_fn=ddsp.core.exp_sigmoid, initial_bias=-5.0
        )

    def call(self, inputs, training=False):
      # --- unify dtypes for features ---
      # Use the layer's compute dtype (float16 if mixed precision is on)
      comp_dtype = self.compute_dtype or tf.float32

      # Cast inputs to comp_dtype for the stack / MLP
      f0_hz = tf.cast(inputs["f0_hz"], comp_dtype)         # [B, T]
      ld_db = tf.cast(inputs["loudness_db"], comp_dtype)   # [B, T]

      # Compute f0_midi in float32 for numerical stability, then cast back
      f0_midi32 = ddsp.core.hz_to_midi(tf.cast(f0_hz, tf.float32))
      f0_midi32 = tf.clip_by_value(f0_midi32, *self.f0midi_range)
      f0_midi = tf.cast(f0_midi32, comp_dtype)

      # Now both tensors have the same dtype -> stack is happy
      x = tf.stack([f0_midi, ld_db], axis=-1)   # [B, T, 2]
      x = self.pre(x)
      x = self.gru(x)
      x = self.post(x)

      # Heads run in comp_dtype; do DSP scalings in float32, then cast if needed
      amp_head  = tf.cast(self.amp_head(x),  tf.float32)
      harm_log  = tf.cast(self.harm_head(x), tf.float32)
      noise_head= tf.cast(self.noise_head(x),tf.float32)

      amp       = ddsp.core.exp_sigmoid(amp_head)               # [B, T, 1] float32
      harm_dist = tf.nn.softmax(harm_log, axis=-1)              # [B, T, H] float32
      noise_mag = ddsp.core.exp_sigmoid(noise_head)             # [B, T, BANDS] float32

      f0_hz_3d  = tf.cast(f0_hz, tf.float32)[..., tf.newaxis]   # [B, T, 1] float32

      # DDSP synths in float32
      audio_h = self.harm(amplitudes=amp, harmonic_distribution=harm_dist, f0_hz=f0_hz_3d)
      audio_n = self.noise(magnitudes=noise_mag)

      # Align lengths if needed
      min_len = tf.minimum(tf.shape(audio_h)[-1], tf.shape(audio_n)[-1])
      audio_h = audio_h[..., :min_len]
      audio_n = audio_n[..., :min_len]

      audio = audio_h + audio_n
      audio = self.reverb(audio)        # Reverb also in float32

      # If the layer is running under mixed precision and you want to return fp16:
      # return tf.cast(audio, comp_dtype)
      return audio                       # keep float32; cast in loss if needed


# Multi-scale spectral loss (linear mag)
loss_fn = SpectralLoss(
    fft_sizes=(512, 256),
    loss_type='L1',
    mag_weight=1.0,
    logmag_weight=0.0,
    delta_time_weight=0.0,
    delta_freq_weight=0.0,
    cumsum_freq_weight=0.0,
    loudness_weight=0.0,
    name='spectral_loss'
)

class DDSPTrainer(keras.Model):
    def __init__(self, ddsp_model, loss_fn):
        super().__init__()
        self.ddsp_model = ddsp_model
        self.loss_fn = loss_fn

    def compile(self, optimizer):
        super().compile()
        self.optimizer = optimizer
        # Detect mixed-precision optimizer once (used inside tf.function)
        self._uses_loss_scale = hasattr(self.optimizer, "get_scaled_loss")

    @tf.function
    def train_step(self, data):
        cond, target = data
        with tf.GradientTape() as tape:
            pred = self.ddsp_model(cond, training=True)
            # keep loss math in float32 for stability
            loss = self.loss_fn(tf.cast(target, tf.float32), tf.cast(pred, tf.float32))
            if self._uses_loss_scale:
                loss_to_minimize = self.optimizer.get_scaled_loss(loss)
            else:
                loss_to_minimize = loss

        grads = tape.gradient(loss_to_minimize, self.ddsp_model.trainable_variables)
        if self._uses_loss_scale:
            grads = self.optimizer.get_unscaled_gradients(grads)

        self.optimizer.apply_gradients(zip(grads, self.ddsp_model.trainable_variables))
        return {"loss": loss}

    @tf.function
    def test_step(self, data):
        cond, target = data
        pred  = self.ddsp_model(cond, training=False)
        loss  = self.loss_fn(tf.cast(target, tf.float32), tf.cast(pred, tf.float32))
        return {"val_loss": loss}

# Trainer + compile (unchanged)
model = DDSPDecoder()
# 3) Optimizer (works with or without mixed precision)

use_mp = mixed_precision.global_policy().name == "mixed_float16"
base_opt = keras.optimizers.Adam(1e-3)
opt = mixed_precision.LossScaleOptimizer(base_opt) if use_mp else base_opt

# 4) Trainer: pass the loss_fn
trainer = DDSPTrainer(model, loss_fn)
trainer.compile(optimizer=opt)

In [20]:
Tprime = 250
dummy = {
    "f0_hz": tf.zeros([1, Tprime], dtype=tf.float32),
    "loudness_db": tf.zeros([1, Tprime], dtype=tf.float32),
}
_ = model(dummy, training=False)  # should run without the broadcast error now
print("Trainable params:", sum(int(np.prod(v.shape)) for v in model.trainable_variables))


Trainable params: 452034


In [None]:
#@title Pre-train
import tensorflow as tf, tensorflow.keras as keras
LOG_DIR = EXP_DIR / "tb"
CKPT_DIR = EXP_DIR / "ckpt"
CKPT_DIR.mkdir(parents=True, exist_ok=True)


# Recreate model & trainer after changes (and after a fresh restart if needed)
tf.keras.backend.clear_session()
model = DDSPDecoder()
use_mp = mixed_precision.global_policy().name == "mixed_float16"
base_opt = keras.optimizers.Adam(1e-3)
opt = mixed_precision.LossScaleOptimizer(base_opt) if use_mp else base_opt

# 4) Trainer: pass the loss_fn
trainer = DDSPTrainer(model, loss_fn)
trainer.compile(optimizer=opt)

# Build variables with a dummy pass (TPRIME frames)
_ = model({"f0_hz": tf.zeros([1, TPRIME]), "loudness_db": tf.zeros([1, TPRIME])}, training=False)

# Minimal callbacks (disable TensorBoard for now to save memory)
callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5, restore_best_weights=True),
    # Save weights infrequently; Keras 3 requires .weights.h5
    keras.callbacks.ModelCheckpoint(filepath=str(CKPT_DIR / "ddsp.best.weights.h5"),
                                    save_weights_only=True, monitor="val_loss", mode="min", save_best_only=True)
]

# If you control batching, set a small batch size upstream. Otherwise, cap steps here:
EPOCHS = 3
history = trainer.fit(
    train_ds.take(50),
    validation_data=val_ds.take(10),
    epochs=2,
    steps_per_epoch=25,
    validation_steps=5,
    verbose=1,
)


Epoch 1/2


In [None]:
#@title Train!
LOG_DIR = EXP_DIR / "tb"
CKPT_DIR = EXP_DIR / "ckpt"
CKPT_DIR.mkdir(parents=True, exist_ok=True)

callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=str(CKPT_DIR / "ddsp.best.weights.h5"),  # <-- fix extension
        save_weights_only=True,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
    ),
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=8,
        restore_best_weights=True,
    ),
    keras.callbacks.TensorBoard(
        log_dir=str(LOG_DIR), write_graph=False, update_freq="epoch"
    ),
]
keras.callbacks.ModelCheckpoint(
    filepath=str(CKPT_DIR / "ddsp.{epoch:03d}-{val_loss:.4f}.weights.h5"),
    save_weights_only=True,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
)
EPOCHS = 50
history = trainer.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
)


Epoch 1/50


## Minimal Model

In [7]:
#@title Safe Runtime Setup

# <<< RUN THIS AS YOUR FIRST CELL AFTER A RUNTIME RESTART >>>
import os
# Disable XLA (can trigger big graphs & crashes in some TF builds)
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
# Let TF grow GPU memory gradually (works only if set before runtime heavy use)
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
# Reduce TF log noise
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf
from tensorflow.keras import mixed_precision

# Make sure we're NOT in mixed precision while debugging stability
mixed_precision.set_global_policy("float32")
print("Compute policy:", mixed_precision.global_policy())

# Also disable XLA via API (belt & suspenders)
tf.config.optimizer.set_jit(False)

# Enable GPU memory growth
for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as e:
        print("Could not set memory growth:", e)
print("GPUs:", tf.config.list_physical_devices("GPU"))


Compute policy: <DTypePolicy "float32">
GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [8]:
#@title Frugal Dataset

# Single source of truth
SR = 16000
WIN_S = 0.5            # <<< half-second windows to minimize memory for the smoke test
FRAME_RATE = 250       # your features fps
TPRIME = int(WIN_S * FRAME_RATE)   # 125
AUDIO_SAMPLES = int(WIN_S * SR)    # 8000

import tensorflow as tf

def _pad_crop_time(x, length):
    x = tf.convert_to_tensor(x)
    tf.debugging.assert_rank(x, 2)
    x = x[:, :length]
    pad = tf.maximum(0, length - tf.shape(x)[1])
    x = tf.pad(x, [[0,0],[0,pad]])
    x = tf.ensure_shape(x, [None, length])
    return x

def _pad_crop_audio(y, length):
    y = tf.convert_to_tensor(y)
    y = tf.cond(tf.equal(tf.rank(y), 2), lambda: y, lambda: tf.reduce_mean(y, axis=-1))
    y = y[:, :length]
    pad = tf.maximum(0, length - tf.shape(y)[1])
    y = tf.pad(y, [[0,0],[0,pad]])
    y = tf.ensure_shape(y, [None, length])
    return y

def fix_batch_shapes(cond, target):
    cond = {
        "f0_hz":      _pad_crop_time(tf.cast(cond["f0_hz"], tf.float32), TPRIME),
        "loudness_db": _pad_crop_time(tf.cast(cond["loudness_db"], tf.float32), TPRIME),
    }
    target = _pad_crop_audio(tf.cast(target, tf.float32), AUDIO_SAMPLES)
    return cond, target

# Clamp an existing dataset post-batch: tiny shuffle/prefetch, single-threaded map
clamp_shuffle = 64
def clamp(ds, do_shuffle=False):
    if do_shuffle:
        ds = ds.shuffle(clamp_shuffle, reshuffle_each_iteration=True)
    ds = ds.map(fix_batch_shapes, num_parallel_calls=1, deterministic=False)
    ds = ds.prefetch(0)   # no prefetch for the smoke test
    return ds

# Force tiny batch size, even if upstream is bigger
def rebatch(ds, bs=1):
    return ds.unbatch().batch(bs, drop_remainder=True)

# Apply to your existing datasets:
train_ds = clamp(rebatch(train_ds, bs=1), do_shuffle=True).take(8)   # 8 tiny batches max
val_ds   = clamp(rebatch(val_ds,   bs=1), do_shuffle=False).take(2)


In [10]:
#@title Minimal Model

import ddsp
import tensorflow as tf
import tensorflow.keras as keras
from ddsp.synths import Harmonic

N_HARMONICS = 16   # tiny

class DDSPDecoder(keras.Model):
    def __init__(self, rnn_units=64, mlp_units=(64,), **kwargs):
        super().__init__(**kwargs)
        self.f0midi_range = (24.0, 84.0)
        self.pre  = keras.layers.Dense(32, activation='relu')
        self.gru  = keras.layers.GRU(rnn_units, return_sequences=True)
        self.post = keras.Sequential([keras.layers.Dense(u, activation='relu') for u in mlp_units])

        self.amp_head   = keras.layers.Dense(1)
        self.harm_head  = keras.layers.Dense(N_HARMONICS)

        self.harm = Harmonic(sample_rate=SR, amp_resample_method='linear')

    def call(self, inputs, training=False):
        # unify dtypes for safe stacking
        f0_hz = tf.cast(inputs["f0_hz"], tf.float32)
        ld_db = tf.cast(inputs["loudness_db"], tf.float32)

        f0_midi = ddsp.core.hz_to_midi(tf.clip_by_value(f0_hz, 1.0, 8000.0))
        f0_midi = tf.clip_by_value(f0_midi, *self.f0midi_range)

        x = tf.stack([f0_midi, ld_db], axis=-1)   # [B, T, 2]
        x = self.pre(x); x = self.gru(x); x = self.post(x)

        amp       = ddsp.core.exp_sigmoid(self.amp_head(x))          # [B,T,1]
        harm_dist = tf.nn.softmax(tf.cast(self.harm_head(x), tf.float32), axis=-1)  # [B,T,H]
        f0_hz_3d  = f0_hz[..., tf.newaxis]                           # [B,T,1]

        audio_h = self.harm(amplitudes=amp, harmonic_distribution=harm_dist, f0_hz=f0_hz_3d)
        return audio_h  # [B, N]


In [11]:
#@title Small Spectral Loss
from ddsp.losses import SpectralLoss

loss_fn = SpectralLoss(
    fft_sizes=(512, 256),   # small
    loss_type='L1',
    mag_weight=1.0,
    logmag_weight=0.0,
    delta_time_weight=0.0,
    delta_freq_weight=0.0,
    cumsum_freq_weight=0.0,
    loudness_weight=0.0,
)

class DDSPTrainer(keras.Model):
    def __init__(self, ddsp_model, loss_fn):
        super().__init__()
        self.ddsp_model = ddsp_model
        self.loss_fn = loss_fn

    # IMPORTANT: run eagerly to keep graphs small/stable for the smoke test
    def compile(self, optimizer):
        super().compile(run_eagerly=True)
        self.optimizer = optimizer

    def train_step(self, data):
        cond, target = data
        with tf.GradientTape() as tape:
            pred = self.ddsp_model(cond, training=True)
            loss = self.loss_fn(tf.cast(target, tf.float32), tf.cast(pred, tf.float32))
        grads = tape.gradient(loss, self.ddsp_model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.ddsp_model.trainable_variables))
        return {"loss": loss}

    def test_step(self, data):
        cond, target = data
        pred  = self.ddsp_model(cond, training=False)
        loss  = self.loss_fn(tf.cast(target, tf.float32), tf.cast(pred, tf.float32))
        return {"val_loss": loss}


In [None]:
#@title Build
tf.keras.backend.clear_session()
model = DDSPDecoder()
# build vars
_ = model({"f0_hz": tf.zeros([1, TPRIME], tf.float32),
           "loudness_db": tf.zeros([1, TPRIME], tf.float32)}, training=False)

print("Trainable params:", sum(int(tf.size(v)) for v in model.trainable_variables))

opt = keras.optimizers.Adam(1e-3, clipnorm=1.0)  # clip for stability
trainer = DDSPTrainer(model, loss_fn)
trainer.compile(optimizer=opt)

# No callbacks, no TensorBoard, tiny steps
history = trainer.fit(
    train_ds, validation_data=val_ds,
    epochs=1, steps_per_epoch=2, validation_steps=1, verbose=1
)


Trainable params: 24177


### Baseline

In [7]:
# --- hard reset setup ---
import os, tensorflow as tf, tensorflow.keras as keras, ddsp
from tensorflow.keras import mixed_precision

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"   # no XLA while stabilizing
mixed_precision.set_global_policy("float32")         # keep dtypes simple

for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception:
        pass
tf.config.optimizer.set_jit(False)

SR, WIN_S, FRAME_RATE = 16000, 0.5, 250
TPRIME, AUDIO_SAMPLES = int(WIN_S*FRAME_RATE), int(WIN_S*SR)

# --- tiny harmonic-only model (same as before) ---
from ddsp.synths import Harmonic
N_HARMONICS = 8

class DDSPDecoder(keras.Model):
    def __init__(self):
        super().__init__()
        self.f0midi_range = (24.0, 84.0)
        self.pre  = keras.layers.Dense(16, activation='relu')
        self.gru  = keras.layers.GRU(32, return_sequences=True)
        self.post = keras.layers.Dense(16, activation='relu')
        self.amp_head  = keras.layers.Dense(1)
        self.harm_head = keras.layers.Dense(N_HARMONICS)
        try:
            self.harm = Harmonic(sample_rate=SR,
                                amp_resample_method='linear',
                                frame_rate=FRAME_RATE)        # preferred
        except TypeError:
            self.harm = Harmonic(sample_rate=SR,
                                amp_resample_method='linear')
            # fallback for older builds:
            self.harm.frame_rate = FRAME_RATE

    def call(self, inputs, training=False):
        f0_hz = tf.cast(inputs["f0_hz"], tf.float32)
        ld_db = tf.cast(inputs["loudness_db"], tf.float32)
        f0_midi = ddsp.core.hz_to_midi(tf.clip_by_value(f0_hz, 1.0, 8000.0))
        f0_midi = tf.clip_by_value(f0_midi, *self.f0midi_range)
        x = tf.stack([f0_midi, ld_db], axis=-1)
        x = self.pre(x); x = self.gru(x); x = self.post(x)
        amp = ddsp.core.exp_sigmoid(self.amp_head(x))
        harm_dist = tf.nn.softmax(tf.cast(self.harm_head(x), tf.float32), axis=-1)
        audio = self.harm(amplitudes=amp, harmonic_distribution=harm_dist, f0_hz=f0_hz[..., tf.newaxis])
        return audio  # [B, N]

# loss & trainer (eager to keep graphs small)
from ddsp.losses import SpectralLoss
loss_fn = SpectralLoss(fft_sizes=(512,256), loss_type='L1', mag_weight=1.0)

class DDSPTrainer(keras.Model):
    def __init__(self, ddsp_model, loss_fn):
        super().__init__()
        self.net = ddsp_model          # avoid naming collisions with keras.Model
        self.loss_fn = loss_fn

    def compile(self, optimizer):
        super().compile(run_eagerly=True)  # eager for stability in the smoke test
        self.optimizer = optimizer

    def train_step(self, data):
        cond, target = data
        with tf.GradientTape() as tape:
            pred = self.net(cond, training=True)
            # üîí ensure identical lengths for loss
            pred   = pred[:, :AUDIO_SAMPLES]
            target = tf.cast(target[:, :AUDIO_SAMPLES], tf.float32)
            loss = self.loss_fn(target, tf.cast(pred, tf.float32))
        grads = tape.gradient(loss, self.net.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.net.trainable_variables))
        return {"loss": loss}

    def test_step(self, data):
        cond, target = data
        pred = self.net(cond, training=False)
        pred   = pred[:, :AUDIO_SAMPLES]
        target = tf.cast(target[:, :AUDIO_SAMPLES], tf.float32)
        loss = self.loss_fn(target, tf.cast(pred, tf.float32))
        return {"val_loss": loss}

# --- in-memory toy dataset (no TFRecords at all) ---
import numpy as np
B = 1  # tiny batch
cond = {"f0_hz": tf.constant(np.zeros((B, TPRIME), np.float32)),
        "loudness_db": tf.constant(np.zeros((B, TPRIME), np.float32))}
target = tf.zeros([B, AUDIO_SAMPLES], tf.float32)

toy = tf.data.Dataset.from_tensor_slices((cond, target)).repeat().batch(1)

# build, print params, do 1-2 tiny steps
tf.keras.backend.clear_session()
model = DDSPDecoder()
_ = model({"f0_hz": tf.zeros([1, TPRIME]), "loudness_db": tf.zeros([1, TPRIME])}, training=False)
print("Trainable params:", sum(int(tf.size(v)) for v in model.trainable_variables))

trainer = DDSPTrainer(model, loss_fn)
trainer.compile(optimizer=keras.optimizers.Adam(1e-3))
history = trainer.fit(toy.take(2), validation_data=toy.take(1), epochs=1,
                      steps_per_epoch=2, validation_steps=1, verbose=1)


Trainable params: 5529
[1m2/2[0m [32m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[37m[0m [1m1s[0m 366ms/step - loss: 0.0000e+00 - val_val_loss: 0.0000e+00


In [35]:
import tensorflow as tf, pathlib

# --- keep the same globals you used for the toy run ---
SR = 16000
WIN_S = 1                   # stay small for the smoke test
FRAME_RATE = 250
TPRIME = int(WIN_S * FRAME_RATE)   # 125
AUDIO_SAMPLES = int(WIN_S * SR)    # 8000

# Point to just a FEW files to start
#TFRECORDS_DIR = pathlib.Path("/path/to/tfrecords")     # <-- change this
train_files = sorted(map(str, (TFRECORDS_DIR / "train").glob("*.tfrecord")))
val_files   = sorted(map(str, (TFRECORDS_DIR / "val").glob("*.tfrecord")))

# Robust parser that handles variable-length sequences
def _parse_ex(serialized):
    spec = {
        "f0_hz":       tf.io.VarLenFeature(tf.float32),
        "loudness_db": tf.io.VarLenFeature(tf.float32),
        "audio":       tf.io.VarLenFeature(tf.float32),
    }
    ex = tf.io.parse_single_example(serialized, spec)
    f0 = tf.sparse.to_dense(ex["f0_hz"])
    ld = tf.sparse.to_dense(ex["loudness_db"])
    y  = tf.sparse.to_dense(ex["audio"])
    return {"f0_hz": f0, "loudness_db": ld}, y

# Crop/pad to fixed sizes (matches the toy run)
def _pad_crop_time(x, length):
    x = x[:length]
    pad = tf.maximum(0, length - tf.shape(x)[0])
    x = tf.pad(x, [[0, pad]])
    return tf.ensure_shape(x, [length])

def _pad_crop_audio(y, length):
    y = y[:length]
    pad = tf.maximum(0, length - tf.shape(y)[0])
    y = tf.pad(y, [[0, pad]])
    return tf.ensure_shape(y, [length])

def _fix_shapes(cond, y):
    cond = {
        "f0_hz":      tf.cast(_pad_crop_time(cond["f0_hz"], TPRIME), tf.float32),
        "loudness_db": tf.cast(_pad_crop_time(cond["loudness_db"], TPRIME), tf.float32),
    }
    y = tf.cast(_pad_crop_audio(y, AUDIO_SAMPLES), tf.float32)
    return cond, y

def make_ds(files):
    # ultra-conservative: single-threaded, tiny buffers, NO prefetch
    ds = tf.data.TFRecordDataset(
        files,
        num_parallel_reads=1,
        buffer_size=64*1024,          # 16KB ingest buffer
        compression_type=""           # set "GZIP" if your files are gzipped
    )
    ds = ds.map(_parse_ex, num_parallel_calls=1, deterministic=True)
    ds = ds.map(_fix_shapes, num_parallel_calls=1, deterministic=True)
    ds = ds.batch(1, drop_remainder=True).shuffle(128)   # tiny batch to minimize memory
    ds = ds.prefetch(1)                        # cap the dataset while testing
    return ds

train_ds_lean = make_ds(train_files)
val_ds_lean   = make_ds(val_files)


In [36]:
def peek(ds, n=1):
    it = iter(ds.take(n))
    cond, y = next(it)
    tf.print("peek shapes  f0:", tf.shape(cond["f0_hz"]),
             "ld:", tf.shape(cond["loudness_db"]),
             "y:", tf.shape(y))
peek(train_ds_lean, 1)

peek shapes  f0: [1 250] ld: [1 250] y: [1 16000]


In [37]:
# Reuse your already-built model / loss / trainer from the toy run:
# model, loss_fn, DDSPTrainer (eager), etc.

history = trainer.fit(
    train_ds_lean,
    validation_data=val_ds_lean,
    epochs=1,
    steps_per_epoch=2,
    validation_steps=1,
    verbose=1
)


[1m2/2[0m [32m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[37m[0m [1m1s[0m 513ms/step - loss: 0.0000e+00 - val_val_loss: 0.0000e+00


In [38]:
TRAIN_STEPS = 200   # how many batches you want per epoch (tweak)
VAL_STEPS   = 40

from tensorflow.data import experimental as tfd

train_run = train_ds_lean.take(TRAIN_STEPS).apply(tfd.assert_cardinality(TRAIN_STEPS))
val_run   = val_ds_lean.take(VAL_STEPS).apply(tfd.assert_cardinality(VAL_STEPS))

history = trainer.fit(
    train_run,
    validation_data=val_run,
    epochs=1,
    steps_per_epoch=TRAIN_STEPS,
    validation_steps=VAL_STEPS,
    verbose=1,
)


[1m200/200[0m [32m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[37m[0m [1m54s[0m 261ms/step - loss: 0.0000e+00 - val_val_loss: 0.0000e+00


In [12]:
from pathlib import Path

cand_patterns = ["*.tfrecord", "*.tfrecords", "*.tfrec", "*.tfrec", "*.tfr"]
train_dir = TFRECORDS_DIR / "train"
val_dir   = TFRECORDS_DIR / "val"

def list_files(p):
    out=[]
    for pat in cand_patterns:
        out += list(p.glob(pat))
    return sorted(out)

train_files = list_files(train_dir)
val_files   = list_files(val_dir)

print("#train files:", len(train_files))
print("#val files:", len(val_files))
print("first few train:", [str(p) for p in train_files[:3]])


#train files: 14
#val files: 4
first few train: ['/content/drive/MyDrive/ddsp-demucs/data/tfrecords/train/train-00000.tfrecord', '/content/drive/MyDrive/ddsp-demucs/data/tfrecords/train/train-00001.tfrecord', '/content/drive/MyDrive/ddsp-demucs/data/tfrecords/train/train-00002.tfrecord']
