# Train a DDSP Autoencoder on GPU

This notebook demonstrates how to install the DDSP library and train it for synthesis based on your own data using our command-line scripts. If run inside of Colab, it will automatically use a free Google Cloud GPU.

At the end, you'll have a custom-trained checkpoint that you can download to use with the [DDSP Timbre Transfer Colab](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/timbre_transfer.ipynb).

<img src="https://storage.googleapis.com/ddsp/additive_diagram/ddsp_autoencoder.png" alt="DDSP Autoencoder figure" width="700">


#### Set your base directory
* In Google Drive, put all of the audio (.wav, .mp3) files you would like to train on in a single folder.
 * Typically works well with 10-20 minutes of audio from a single monophonic source (also, one acoustic environment).
* Use the file browser in the left panel to find a folder with your audio, right-click **"Copy Path", paste below**, and run the cell.

In [4]:
#@markdown (ex. `/content/drive/My Drive/...`) Leave blank to skip loading from Drive.
DRIVE_DIR = 'monoGuitarDataset/singlecoil' #@param {type: "string"}

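import os

# Only check the folder if one was given (blank means "skip loading from Drive").
if DRIVE_DIR:
    assert os.path.exists(DRIVE_DIR), f'Folder not found: {DRIVE_DIR}'
    print('Drive Folder Exists:', DRIVE_DIR)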


Drive Folder Exists: /content/drive/MyDrive/monoGuitarDataset/singlecoil
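The guidance above suggests 10-20 minutes of audio. As a quick check, here is a minimal sketch that totals the duration of the audio in the folder using only Python's standard `wave` module. `total_wav_minutes` is a hypothetical helper; it assumes standard uncompressed PCM `.wav` files and ignores `.mp3` files, which `wave` cannot read.

```python
import glob
import os
import wave


def total_wav_minutes(folder):
    """Total duration in minutes of the PCM .wav files directly inside `folder`."""
    total_secs = 0.0
    for path in glob.glob(os.path.join(folder, '*.wav')):
        with wave.open(path, 'rb') as w:
            # Duration = number of sample frames / sample frames per second.
            total_secs += w.getnframes() / w.getframerate()
    return total_secs / 60.0


# Example: print(f'{total_wav_minutes(DRIVE_DIR):.1f} minutes of WAV audio')
```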


## Make directories to save model and data

In [5]:
import os

AUDIO_DIR = 'data/audio'
AUDIO_FILEPATTERN = AUDIO_DIR + '/*'
os.makedirs(AUDIO_DIR, exist_ok=True)

SAVE_DIR = os.path.join(DRIVE_DIR, 'ddsp-solo-instrument')
os.makedirs(SAVE_DIR, exist_ok=True)

## Prepare Dataset


#### Upload training audio

Copy the audio files you want to train on from `DRIVE_DIR` into `AUDIO_DIR`.

In [6]:
import glob
import os
import shutil

# Collect the .mp3 and .wav files in DRIVE_DIR.
mp3_files = glob.glob(os.path.join(DRIVE_DIR, '*.mp3'))
wav_files = glob.glob(os.path.join(DRIVE_DIR, '*.wav'))
audio_files = mp3_files + wav_files

# Copy each file into AUDIO_DIR, replacing spaces in file names with underscores.
os.makedirs(AUDIO_DIR, exist_ok=True)
for fname in audio_files:
    target_name = os.path.join(AUDIO_DIR, os.path.basename(fname).replace(' ', '_'))
    print('Copying {} to {}'.format(fname, target_name))
    shutil.copy(fname, target_name)


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



Copying /content/drive/MyDrive/monoGuitarDataset/singlecoil/G81-51111-1111-20916.wav to data/audio/G81-51111-1111-20916.wav
Copying /content/drive/MyDrive/monoGuitarDataset/singlecoil/G83-68509-1111-21044.wav to data/audio/G83-68509-1111-21044.wav
Copying /content/drive/MyDrive/monoGuitarDataset/singlecoil/G83-51301-1111-21010.wav to data/audio/G83-51301-1111-21010.wav
Copying /content/drive/MyDrive/monoGuitarDataset/singlecoil/G81-62503-1111-20960.wav to data/audio/G81-62503-1111-20960.wav
Copying /content/drive/MyDrive/monoGuitarDataset/singlecoil/G81-64600-1111-20970.wav to data/audio/G81-64600-1111-20970.wav
Copying /content/drive/MyDrive/monoGuitarDataset/singlecoil/G81-66602-1111-20972.wav to data/audio/G81-66602-1111-20972.wav
Copying /content/drive/MyDrive/monoGuitarDataset/singlecoil/G83-65601-1111-21049.wav to data/audio/G83-65601-1111-21049.wav
Copying /content/drive/MyDrive/monoGuitarDataset/singlecoil/G81-59500-1111-20957.wav to data/audio/G81-59500-1111-20957.wav
Copying 
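The copy step above globs only lowercase `*.mp3` and `*.wav`, so on a case-sensitive file system a file such as `take1.WAV` is silently skipped. A minimal sketch of a case-insensitive alternative; `plan_audio_copies` is a hypothetical helper that only plans the copies (source, destination) and applies the same space-to-underscore renaming.

```python
import os


def plan_audio_copies(src_dir, dst_dir, extensions=('.wav', '.mp3')):
    """Return sorted (source, destination) pairs for audio files in `src_dir`.

    Extensions are matched case-insensitively, and spaces in file names are
    replaced with underscores in the destination path.
    """
    pairs = []
    for name in sorted(os.listdir(src_dir)):
        src = os.path.join(src_dir, name)
        if os.path.isfile(src) and os.path.splitext(name)[1].lower() in extensions:
            pairs.append((src, os.path.join(dst_dir, name.replace(' ', '_'))))
    return pairs
```

Usage: `for src, dst in plan_audio_copies(DRIVE_DIR, AUDIO_DIR): shutil.copy(src, dst)`. Separating planning from copying makes it easy to print and check the list before touching any files.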

### Preprocess raw audio into TFRecord dataset

We need to do some preprocessing on the raw audio you uploaded to get it into the correct format for training. This involves splitting the full audio into short examples (4 seconds by default, set with `--example_secs`), inferring the fundamental frequency (or "pitch") with [CREPE](http://github.com/marl/crepe), and computing the loudness. These features are then stored in a sharded [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file for easier loading. Depending on the amount of input audio, this process usually takes a few minutes.

* (Optional) Transfer dataset from Drive. If you've already created a dataset in a previous run, this cell skips the dataset creation step and copies the existing TFRecord files from `$DRIVE_DIR/data`.

In [None]:
import glob
import os
import shutil
import subprocess

TRAIN_TFRECORD = 'data/train.tfrecord'
TRAIN_TFRECORD_FILEPATTERN = TRAIN_TFRECORD + '*'
AUDIO_FILEPATTERN = 'data/audio/*'
os.makedirs('data', exist_ok=True)

# Copy dataset from drive if dataset has already been created.
# (New TFRecords are saved to $DRIVE_DIR/data at the end of this cell.)
drive_data_dir = os.path.join(DRIVE_DIR, 'data')
drive_dataset_files = glob.glob(os.path.join(drive_data_dir, 'train.tfrecord*'))

if DRIVE_DIR and drive_dataset_files:
    for fname in drive_dataset_files:
        shutil.copy(fname, 'data/')
    print('Copied existing dataset from', drive_data_dir)

else:
    # Make a new dataset.
    if not glob.glob(AUDIO_FILEPATTERN):
        raise ValueError('No audio files found in data/audio/. Please use the previous cell to upload.')

    # Run the ddsp_prepare_tfrecord command-line script that ships with ddsp.
    # Add e.g. '--example_secs=1.5' to change the example length (default 4 s).
    subprocess.run([
        'ddsp_prepare_tfrecord',
        '--input_audio_filepatterns=' + AUDIO_FILEPATTERN,
        '--output_tfrecord_path=' + TRAIN_TFRECORD,
        '--num_shards=10',
        '--sample_rate=16000',
        '--alsologtostderr',
    ], check=True)

    # Verify that TFRecord was created.
    tfrecords_created = glob.glob(TRAIN_TFRECORD_FILEPATTERN)
    print('TFRecord files generated:', tfrecords_created)
    if not tfrecords_created:
        raise RuntimeError('TFRecord generation failed. No files created at data/train.tfrecord*.')

    # Copy dataset to drive for safe-keeping.
    if DRIVE_DIR:
        os.makedirs(drive_data_dir, exist_ok=True)
        print(f'Saving TFRecords to: {drive_data_dir}')
        for tfrecord in tfrecords_created:
            shutil.copy(tfrecord, drive_data_dir)


2025-05-10 02:08:26.663810: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
I0510 02:08:27.133851 132362418708480 environments.py:376] Default Python SDK image for environment is apache/beam_python3.11_sdk:2.48.0
I0510 02:08:27.240300 132362418708480 statecache.py:234] Creating 
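The first preprocessing step, cutting long recordings into fixed-length training examples, can be sketched with NumPy as a sliding window. This is a simplified illustration, not the library's implementation (which may pad or handle edges differently); `split_into_examples` and `frames_per_example` are hypothetical helpers, and the defaults of 16 kHz audio, a 1-second hop between windows, and a 250 Hz feature frame rate are assumptions. Under those assumptions a 4-second example carries 1000 frames of f0 and loudness.

```python
import numpy as np


def split_into_examples(audio, sample_rate=16000, example_secs=4.0, hop_secs=1.0):
    """Split a 1-D signal into overlapping windows of `example_secs` seconds.

    Returns an array of shape (num_examples, example_secs * sample_rate).
    Audio shorter than one window yields zero examples.
    """
    n_window = int(example_secs * sample_rate)
    n_hop = int(hop_secs * sample_rate)
    if len(audio) < n_window:
        return np.zeros((0, n_window), dtype=audio.dtype)
    starts = range(0, len(audio) - n_window + 1, n_hop)
    return np.stack([audio[s:s + n_window] for s in starts])


def frames_per_example(example_secs=4.0, frame_rate=250):
    """Number of feature frames (f0, loudness) in one example."""
    return int(example_secs * frame_rate)
```

For example, 10 seconds of audio with 4-second windows and a 1-second hop gives 7 overlapping examples of 64000 samples each.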

### Save dataset statistics for timbre transfer

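Quantile normalization helps match the loudness of timbre transfer inputs to the loudness of the dataset, so let's compute dataset statistics here and save them in a pickle file. Note that the helper below only records the mean and standard deviation of loudness, f0, and f0 confidence; it does not fit a quantile transform.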
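To make "quantile normalization" concrete: each input value is replaced by the value found at the same quantile of the reference (dataset) distribution, so the input keeps its shape over time but takes on the dataset's range. A minimal NumPy sketch of that idea; `quantile_match` is a hypothetical helper, and the DDSP library's own implementation may differ (for example, by fitting a transform once on the dataset and reusing it).

```python
import numpy as np


def quantile_match(x, reference):
    """Map each value in `x` to the value at the same quantile of `reference`.

    The output preserves the ordering of `x` but follows the distribution
    of `reference`.
    """
    x = np.asarray(x, dtype=float)
    reference = np.asarray(reference, dtype=float)
    # Rank of each element within x (0 = smallest), rescaled to a quantile in [0, 1].
    ranks = np.argsort(np.argsort(x))
    q = ranks / max(len(x) - 1, 1)
    # Look up the same quantiles in the reference distribution.
    return np.quantile(reference, q)
```

For instance, `quantile_match([3, 0, 4, 1, 2], [10, 20, 30, 40, 50])` returns `[40, 10, 50, 20, 30]`: the loudest input frame maps to the loudest dataset value, and so on.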

In [None]:
import numpy as np
import tensorflow as tf


def compute_dataset_statistics(data_provider, num_examples=100):
    """Compute mean/std of loudness, f0, and f0 confidence over up to `num_examples` examples.

    Records that raise errors while being read are dropped by tf.data's ignore_errors.
    """
    # Errors from corrupted records are raised by the dataset iterator itself,
    # so they must be filtered at the dataset level, not inside the loop body.
    dataset = data_provider.get_dataset()
    dataset = dataset.apply(tf.data.experimental.ignore_errors()).shuffle(100)

    features = {'loudness_db': [], 'f0_hz': [], 'f0_confidence': []}
    for ex in dataset.take(num_examples):
        for key in features:
            features[key].append(ex[key].numpy())

    if not features['loudness_db']:
        raise ValueError('No valid examples found in the dataset.')

    stats = {}
    for key, values in features.items():
        values = np.concatenate(values)
        stats[key] = {'mean': np.mean(values), 'std': np.std(values)}

    return stats


In [None]:
import os
import pickle

from ddsp.training.data import TFRecordProvider

data_provider = TFRecordProvider(TRAIN_TFRECORD_FILEPATTERN)
PICKLE_FILE_PATH = os.path.join(SAVE_DIR, 'dataset_statistics.pkl')

# Compute stats with compute_dataset_statistics() from the previous cell.
stats = compute_dataset_statistics(data_provider)

# Save to pickle file (create SAVE_DIR first in case it does not exist yet).
os.makedirs(SAVE_DIR, exist_ok=True)
with open(PICKLE_FILE_PATH, 'wb') as f:
    pickle.dump(stats, f)
print('Saved dataset statistics to', PICKLE_FILE_PATH)
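The statistics saved above are plain means and standard deviations. As one hypothetical illustration of how they could be used, the sketch below offsets an input loudness curve so its mean matches the dataset's mean loudness. This is an assumption, not the timbre transfer colab's code, and that colab may expect different keys in `dataset_statistics.pkl` than the ones written here; `shift_loudness_to_dataset` is a hypothetical helper.

```python
import numpy as np


def shift_loudness_to_dataset(loudness_db, stats):
    """Offset a loudness curve (dB) so its mean equals the dataset mean loudness.

    `stats` is assumed to have the structure saved above:
    stats['loudness_db']['mean'] and stats['loudness_db']['std'].
    """
    loudness_db = np.asarray(loudness_db, dtype=float)
    return loudness_db - loudness_db.mean() + stats['loudness_db']['mean']
```

Usage: load the file with `with open(PICKLE_FILE_PATH, 'rb') as f: stats = pickle.load(f)`, then pass an input's `loudness_db` array. A constant dB offset changes overall level while leaving the dynamics (and standard deviation) of the input unchanged.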


Let's load the dataset in the `ddsp` library and have a look at one of the examples.

In [None]:
import ddsp.training
from ddsp.training.data import TFRecordProvider
from matplotlib import pyplot as plt
import numpy as np
import soundfile as sf
import os

# Setup data provider
print(TRAIN_TFRECORD_FILEPATTERN)
data_provider = TFRecordProvider(TRAIN_TFRECORD_FILEPATTERN)
dataset = data_provider.get_dataset(shuffle=False)

# Get one example
try:
    ex = next(iter(dataset))
except StopIteration:
    raise ValueError(
        'TFRecord contains no examples. Please try re-running the pipeline with different audio file(s).')

# Plot a spectrogram (replaces colab_utils.specplot).
plt.figure(figsize=(14, 4))
plt.specgram(ex['audio'].numpy(), Fs=16000)
plt.title('Spectrogram')
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.show()

# Save audio to a file (replaces colab_utils.play) so you can play it on a local machine.
output_audio_path = os.path.join(SAVE_DIR, 'example_audio.wav')
sf.write(output_audio_path, ex['audio'].numpy(), 16000)
print(f"✅ Audio saved to {output_audio_path}")

# === Plot loudness, F0, confidence ===
# Build the time axis from the example itself so it matches any --example_secs setting.
n_frames = len(ex['loudness_db'])
duration_secs = len(ex['audio']) / 16000
x = np.linspace(0, duration_secs, n_frames)
f, ax = plt.subplots(3, 1, figsize=(14, 6), sharex=True)
ax[0].set_ylabel('loudness_db')
ax[0].plot(x, ex['loudness_db'])
ax[1].set_ylabel('F0_Hz')
ax[1].plot(x, ex['f0_hz'])
ax[2].set_ylabel('F0_confidence')
ax[2].set_xlabel('seconds')
ax[2].plot(x, ex['f0_confidence'])
plt.tight_layout()
plt.show()


## Train Model

We will now train a "solo instrument" model. This means the model is conditioned only on the fundamental frequency (f0) and loudness, with no instrument ID or latent timbre feature. If you uploaded audio of multiple instruments, the neural network you train will attempt to model all timbres, but will likely associate certain timbres with different f0 and loudness conditions.
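To make "conditioned only on f0 and loudness" concrete: in DDSP the network turns those two signals into synthesizer controls, and a harmonic additive synthesizer renders audio from them. A minimal NumPy sketch of that synthesis step alone, assuming per-sample controls; the library works on frame-rate controls that it upsamples, and the DDSP paper pairs the harmonic synthesizer with filtered noise and reverb. `harmonic_synth` is hypothetical and is not the `ddsp.synths` API.

```python
import numpy as np


def harmonic_synth(f0_hz, amplitudes, harmonic_distribution, sample_rate=16000):
    """Render a sum of harmonics of f0 from per-sample controls.

    f0_hz: shape (n_samples,), fundamental frequency in Hz.
    amplitudes: shape (n_samples,), overall amplitude.
    harmonic_distribution: shape (n_samples, n_harmonics), relative harmonic weights.
    """
    f0_hz = np.asarray(f0_hz, dtype=float)
    amplitudes = np.asarray(amplitudes, dtype=float)
    harmonic_distribution = np.asarray(harmonic_distribution, dtype=float)
    n_harmonics = harmonic_distribution.shape[1]
    # Frequency of harmonic k is k * f0.
    freqs = f0_hz[:, None] * np.arange(1, n_harmonics + 1)[None, :]
    # Zero out harmonics at or above Nyquist to avoid aliasing.
    weights = np.where(freqs < sample_rate / 2, harmonic_distribution, 0.0)
    # Normalize so the remaining weights sum to 1 at each sample.
    weights = weights / np.maximum(weights.sum(axis=1, keepdims=True), 1e-7)
    # Phase is the running sum (integral) of instantaneous frequency.
    phases = 2 * np.pi * np.cumsum(freqs / sample_rate, axis=0)
    return amplitudes * np.sum(weights * np.sin(phases), axis=1)
```

Because the weights are normalized, the output never exceeds the requested amplitude, and the timbre is controlled entirely by the harmonic distribution that the network predicts.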

First, let's start up a [TensorBoard](https://www.tensorflow.org/tensorboard) to monitor our loss as training proceeds.

Initially, TensorBoard will report `No dashboards are active for the current data set.`, but once training begins, the dashboards should appear.

In [None]:
%reload_ext tensorboard
import tensorboard as tb
tb.notebook.start('--logdir "{}"'.format(SAVE_DIR))

### We will now begin training.

Note that we specify [gin configuration](https://github.com/google/gin-config) files for both the model architecture (`models/solo_instrument.gin`) and the dataset (`datasets/tfrecord.gin`), which are both predefined in the library. You could also create your own. We then override some of the specific params: `batch_size` (which is defined in the model gin file) and the TFRecord file pattern (which is defined in the dataset file).

### Training Notes:
* Models typically perform well when the loss drops to the range of ~4.5-5.0.
* Depending on the dataset, this usually takes 5k-30k training steps.
* The default is 30k steps, but you can stop training at any time. For timbre transfer, it's best to stop before the loss drops too far below ~5.0 to avoid overfitting.
* On the Colab GPU, this can take from around 3-20 hours.
* We **highly recommend** saving checkpoints directly to your Drive account, as Colab sessions reset after about 12 hours and you may lose all of your checkpoints.
* By default, checkpoints will be saved every 300 steps with a maximum of 10 checkpoints (at ~60MB/checkpoint this is ~600MB). Feel free to adjust these numbers depending on how often you want to save and how much space you have on your drive.
* If you're restarting a session and `DRIVE_DIR` points to a directory that was previously used for training, training should resume from the last checkpoint.
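The storage figures in the notes follow from simple arithmetic on the training flags. A small sketch for re-checking them after changing `num_steps`, `steps_per_save`, or `checkpoints_to_keep`; `checkpoint_budget` is a hypothetical helper, and the ~60 MB per checkpoint figure comes from the notes above and will vary by model.

```python
def checkpoint_budget(num_steps=30000, steps_per_save=300,
                      checkpoints_to_keep=10, mb_per_checkpoint=60):
    """Return (number of saves over the run, approx. MB on disk for the kept checkpoints)."""
    num_saves = num_steps // steps_per_save
    # Older checkpoints are deleted, so at most `checkpoints_to_keep` stay on disk.
    kept = min(num_saves, checkpoints_to_keep)
    return num_saves, kept * mb_per_checkpoint
```

With the defaults this gives 100 saves over the run and about 600 MB on disk; a short 1500-step run would save 5 times and use about 300 MB.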

In [None]:
!ddsp_run \
  --mode=train \
  --alsologtostderr \
  --save_dir="$SAVE_DIR" \
  --gin_file=models/solo_instrument.gin \
  --gin_file=datasets/tfrecord.gin \
  --gin_param="TFRecordProvider.file_pattern='$TRAIN_TFRECORD_FILEPATTERN'" \
  --gin_param="batch_size=16" \
  --gin_param="train_util.train.num_steps=30000" \
  --gin_param="train_util.train.steps_per_save=300" \
  --gin_param="trainers.Trainer.checkpoints_to_keep=10"
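The same flags can be assembled in plain Python, which is handy outside a notebook where `!` shell commands are unavailable. A minimal sketch; `build_ddsp_run_args` is a hypothetical helper that only uses the flags shown in the training cell above. Each gin value is passed already written as a gin literal, so string values keep their own single quotes.

```python
def build_ddsp_run_args(save_dir, gin_files, gin_params):
    """Return an argv list for `ddsp_run --mode=train`.

    gin_params maps a gin binding name to a value already written as a gin
    literal, e.g. {'batch_size': '16'} or
    {'TFRecordProvider.file_pattern': "'data/train.tfrecord*'"}.
    """
    args = ['ddsp_run', '--mode=train', '--alsologtostderr', f'--save_dir={save_dir}']
    args += [f'--gin_file={path}' for path in gin_files]
    args += [f'--gin_param={name}={value}' for name, value in gin_params.items()]
    return args


# Example:
# import subprocess
# subprocess.run(build_ddsp_run_args(SAVE_DIR,
#                                    ['models/solo_instrument.gin', 'datasets/tfrecord.gin'],
#                                    {'batch_size': '16'}), check=True)
```

Passing a list to `subprocess.run` (without `shell=True`) means no shell is involved, so the `*` in the TFRecord file pattern reaches gin unexpanded and needs no extra quoting.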

## Resynthesis

Check how well the model reconstructs the training data.
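Besides listening and comparing spectrograms, reconstruction quality can be put into a single number. A minimal NumPy sketch of a multi-scale spectral distance, in the spirit of the multi-scale spectral loss described in the DDSP paper; the FFT sizes, hop, window, and log offset here are assumptions, and the library's own loss may weight or scale terms differently. `multiscale_spectral_distance` is a hypothetical helper.

```python
import numpy as np


def stft_magnitude(x, fft_size):
    """Magnitude STFT using a Hann window and a hop of fft_size // 4 (75% overlap)."""
    hop = fft_size // 4
    window = np.hanning(fft_size)
    n_frames = 1 + (len(x) - fft_size) // hop
    frames = np.stack([x[i * hop:i * hop + fft_size] * window for i in range(n_frames)])
    return np.abs(np.fft.rfft(frames, axis=-1))


def multiscale_spectral_distance(a, b, fft_sizes=(2048, 1024, 512, 256, 128, 64)):
    """Sum over FFT sizes of mean L1 distances between magnitude and log-magnitude spectrograms."""
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    # Compare only the overlapping length in case the two signals differ slightly.
    n = min(len(a), len(b))
    if n < max(fft_sizes):
        raise ValueError('Signals must be at least max(fft_sizes) samples long.')
    a, b = a[:n], b[:n]
    total = 0.0
    for fft_size in fft_sizes:
        mag_a = stft_magnitude(a, fft_size)
        mag_b = stft_magnitude(b, fft_size)
        total += np.mean(np.abs(mag_a - mag_b))
        total += np.mean(np.abs(np.log(mag_a + 1e-6) - np.log(mag_b + 1e-6)))
    return float(total)
```

Usage after the cell below: `multiscale_spectral_distance(audio.numpy(), audio_gen.numpy())`. Lower is better; comparing at several FFT sizes rewards matching both fine frequency detail and fast transients.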

In [None]:
import ddsp.training
import gin
from matplotlib import pyplot as plt
import numpy as np
import soundfile as sf
import os

# Setup data provider
data_provider = ddsp.training.data.TFRecordProvider(TRAIN_TFRECORD_FILEPATTERN)
dataset = data_provider.get_batch(batch_size=1, shuffle=False)

# Get one batch
try:
    batch = next(iter(dataset))
except StopIteration:
    raise ValueError(
        'TFRecord contains no examples. Please try re-running the pipeline with '
        'different audio file(s).')

# Parse the gin config
gin_file = os.path.join(SAVE_DIR, 'operative_config-0.gin')
gin.parse_config_file(gin_file)

# Load model
model = ddsp.training.models.Autoencoder()
model.restore(SAVE_DIR)

# Resynthesize audio
outputs = model(batch, training=False)
audio_gen = model.get_audio_from_outputs(outputs)
audio = batch['audio']

# Plot a spectrogram of the first item in a batch (replaces colab_utils.specplot).
def plot_specgram(waveform, sample_rate=16000, title='Spectrogram'):
    plt.figure(figsize=(14, 4))
    plt.specgram(waveform.numpy()[0], Fs=sample_rate)
    plt.title(title)
    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (Hz)')
    plt.show()

# Save the first item in a batch to a WAV file (replaces colab_utils.play).
def save_audio(waveform, path, sample_rate=16000):
    sf.write(path, waveform.numpy()[0], sample_rate)
    print(f"✅ Saved audio to {path}")

# Plot and save original audio
print('Original Audio')
plot_specgram(audio)
save_audio(audio, os.path.join(SAVE_DIR, 'original_audio.wav'))

# Plot and save resynthesized audio
print('Resynthesis')
plot_specgram(audio_gen)
save_audio(audio_gen, os.path.join(SAVE_DIR, 'resynthesis_audio.wav'))


## Download Checkpoint

The cell below packages the final checkpoint, its gin config, and the dataset statistics into a zip file. You are then ready to use it in the [DDSP Timbre Transfer Colab](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/timbre_transfer.ipynb).

In [None]:
import tensorflow as tf
import os
import subprocess
import shutil

# Define filenames
CHECKPOINT_ZIP = 'my_solo_instrument.zip'

# Find latest checkpoint
latest_checkpoint_path = tf.train.latest_checkpoint(SAVE_DIR)
if latest_checkpoint_path is None:
    raise ValueError("No checkpoint found in SAVE_DIR")

latest_checkpoint_fname = os.path.basename(latest_checkpoint_path)

# Create the zip file
zip_command = (
    f'cd "{SAVE_DIR}" && zip {CHECKPOINT_ZIP} {latest_checkpoint_fname}* '
    'operative_config-0.gin dataset_statistics.pkl'
)
subprocess.run(zip_command, shell=True, check=True)

# Copy zip to current directory
src = os.path.join(SAVE_DIR, CHECKPOINT_ZIP)
dst = os.path.join('./', CHECKPOINT_ZIP)
shutil.copy(src, dst)

print(f"✅ Checkpoint ZIP saved as {dst}")

# Print instructions for manual download
print("\n📦 To download the zip file to your local machine, run this from your laptop:")
print(f"scp username@cluster:{os.getcwd()}/{CHECKPOINT_ZIP} .")
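The cell above shells out to the `zip` command, which may not be installed on every machine or cluster. A minimal alternative sketch using Python's standard `zipfile` module; `zip_checkpoint` is a hypothetical helper, and the checkpoint prefix is assumed to look like the basename returned by `tf.train.latest_checkpoint` (for example `ckpt-30000`). Matching `prefix + '.*'` rather than `prefix + '*'` avoids also packing a different step, such as `ckpt-3000.index` when the prefix is `ckpt-300`.

```python
import glob
import os
import zipfile


def zip_checkpoint(save_dir, checkpoint_prefix, zip_path,
                   extra_files=('operative_config-0.gin', 'dataset_statistics.pkl')):
    """Zip the checkpoint shards and extra files from `save_dir` into `zip_path`.

    Files are stored under their base names, like `cd save_dir && zip ...`.
    Returns the list of names written to the archive.
    """
    # The checkpoint consists of '<prefix>.index' and '<prefix>.data-*' shards.
    paths = sorted(glob.glob(os.path.join(save_dir, checkpoint_prefix + '.*')))
    paths += [os.path.join(save_dir, name) for name in extra_files]
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
        for path in paths:
            zf.write(path, arcname=os.path.basename(path))
    return [os.path.basename(p) for p in paths]
```

Usage: `zip_checkpoint(SAVE_DIR, latest_checkpoint_fname, CHECKPOINT_ZIP)`.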
