# Train a DDSP Autoencoder on GPU

This notebook demonstrates how to install the DDSP library and train it for synthesis based on your own data using our command-line scripts. If run inside of Colab, it will automatically use a free Google Cloud GPU.

At the end, you'll have a custom-trained checkpoint that you can download to use with the [DDSP Timbre Transfer Colab](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/timbre_transfer.ipynb).

<img src="https://storage.googleapis.com/ddsp/additive_diagram/ddsp_autoencoder.png" alt="DDSP Autoencoder figure" width="700">


## Make directories to save model and data

In [None]:
from pathlib import Path
import os

PROJECT_DIR = Path('/home/luca/Development')

DATA_DIR = PROJECT_DIR.joinpath('data_l2_fixed_3')
DATA_DIR.mkdir(exist_ok=True)

DATASET_DIR = PROJECT_DIR.joinpath('dataset')

DATASET = 'chitarra_michele_rossi/unite'

AUDIO_DIR = DATASET_DIR.joinpath(DATASET)

TFRECORDS_DIR = PROJECT_DIR.joinpath('TFRecords').joinpath(DATASET)
TFRECORDS_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_TFRECORD_FILEPATTERN = TFRECORDS_DIR.joinpath('train.tfrecord-train*')
TEST_TFRECORD_FILEPATTERN = TFRECORDS_DIR.joinpath('train.tfrecord-eval*')
os.environ['LD_LIBRARY_PATH'] = os.environ['CONDA_PREFIX'] + '/lib'
# os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=' + os.environ['CONDA_PREFIX'] + '/lib'
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/home/luca/miniconda3/envs/tf'
print(os.environ)
!printenv

## Enable memory growth
This prevents TensorFlow from allocating all available GPU memory as soon as it starts.

In [None]:
import tensorflow as tf

physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
  try:
    for gpu in physical_devices:
      tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized.
    print(e)

### Preprocess raw audio into TFRecord dataset

We need to do some preprocessing on the raw audio to get it into the correct format for training. This involves splitting the full audio into short examples (10 seconds each, taken every 1 second, with the settings in the cell below), inferring the fundamental frequency (or "pitch") with [CREPE](http://github.com/marl/crepe), and computing the loudness. These features will then be stored in a sharded [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file for easier loading. Depending on the amount of input audio, this process usually takes a few minutes.

* (Optional) Reuse an existing dataset. If `TFRECORDS_DIR` already contains files from a previous run, this cell skips the dataset creation step.
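As a rough guide to how much training data the windowing produces, the sketch below counts the overlapping examples cut from one recording. It is a simplified illustration and not the `ddsp_prepare_tfrecord` implementation: the helper names are hypothetical, it assumes no padding of the final window, and the 250 Hz feature frame rate is an assumption inferred from the 2500 frames per 10-second example plotted later in this notebook.

```python
import math


def count_examples(audio_secs, example_secs=10.0, hop_secs=1.0):
  """Number of full windows that fit in a recording (assumes no padding)."""
  if audio_secs < example_secs:
    return 0
  return int(math.floor((audio_secs - example_secs) / hop_secs)) + 1


def frames_per_example(example_secs=10.0, frame_rate=250):
  """Feature frames (f0, loudness) per example at an assumed frame rate."""
  return int(example_secs * frame_rate)


# A hypothetical 60-second recording, using the flag values from the cell below.
print(count_examples(60.0, example_secs=10.0, hop_secs=1.0))
print(frames_per_example(example_secs=10.0))
```

With a 1-second hop, neighbouring examples share most of their audio, so the example count grows much faster than the amount of distinct material.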

In [None]:
from pathlib import Path

dataset_files = list(TFRECORDS_DIR.glob('*'))
if len(dataset_files) == 0:
    # glob() returns a generator, which is always truthy, so test with any().
    if not any(AUDIO_DIR.glob('*')):
        raise ValueError('No audio files found in {}'.format(AUDIO_DIR))

    AUDIO_FILEPATTERN = str(AUDIO_DIR.joinpath('*les*'))
    TRAIN_TFRECORD = str(TFRECORDS_DIR.joinpath('train.tfrecord'))

    !ddsp_prepare_tfrecord \
        --input_audio_filepatterns="$AUDIO_FILEPATTERN" \
        --output_tfrecord_path="$TRAIN_TFRECORD" \
        --num_shards=10 \
        --example_secs=10 \
        --hop_secs=1 \
        --eval_split_fraction=0.2 \
        --alsologtostderr

### Save dataset statistics for timbre transfer

Quantile normalization helps match the loudness of timbre transfer inputs to the
loudness of the dataset, so let's calculate the statistics here and save them in a pickle file.
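To make the idea concrete, here is a minimal NumPy sketch of quantile matching. It illustrates the general technique only and is not the code behind `local_utils.save_dataset_statistics`: the function name `quantile_match` and the synthetic loudness curves are assumptions made for this example.

```python
import numpy as np


def quantile_match(source, reference, n_quantiles=100):
  """Map `source` values so their distribution follows that of `reference`."""
  q = np.linspace(0.0, 1.0, n_quantiles)
  source_q = np.quantile(source, q)        # Quantiles of the input loudness.
  reference_q = np.quantile(reference, q)  # Quantiles of the dataset loudness.
  # Piecewise-linear lookup: a value at the k-th source quantile is moved
  # to the k-th reference quantile.
  return np.interp(source, source_q, reference_q)


# Synthetic example: a quiet input curve matched to a louder dataset (in dB).
input_db = np.linspace(-60.0, -30.0, 500)
dataset_db = np.linspace(-40.0, -10.0, 500)
matched_db = quantile_match(input_db, dataset_db)
print(matched_db.min(), matched_db.max())
```

Because the mapping is monotonic, the shape of the loudness contour is preserved while its range moves to what the model saw during training.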

In [None]:
from ddsp.local import local_utils
import ddsp.training

data_provider = ddsp.training.data.TFRecordProvider(str(TRAIN_TFRECORD_FILEPATTERN))
dataset = data_provider.get_dataset(shuffle=False)
PICKLE_FILE_PATH = DATA_DIR.joinpath('dataset_statistics.pkl')

_ = local_utils.save_dataset_statistics(data_provider, PICKLE_FILE_PATH, batch_size=1)

Let's load the dataset in the `ddsp` library and have a look at one of the examples.

In [None]:
from ddsp.local import local_utils
from matplotlib import pyplot as plt
import numpy as np

batch_size = 3
sequence_length = 2
audio_rate = 32000
data_rate = 1000

data_provider = ddsp.training.data.TFRecordProvider(str(TRAIN_TFRECORD_FILEPATTERN))
# dataset = data_provider.get_dataset(shuffle=False)
dataset = data_provider.get_batch(batch_size=batch_size, shuffle=False)

dataset_iter = iter(dataset)

audio = np.empty((batch_size,0))
loudness = np.empty((batch_size,0))
f0 = np.empty((batch_size,0))
f0_confidence = np.empty((batch_size,0))
# Concatenate `sequence_length` consecutive batches along the time axis.
for n in range(sequence_length):
  try:
    ex = next(dataset_iter)
    audio = np.concatenate((audio, ex['audio']), axis=-1)
    loudness = np.concatenate((loudness, ex['loudness_db']), axis=-1)
    f0 = np.concatenate((f0, ex['f0_hz']), axis=-1)
    f0_confidence = np.concatenate((f0_confidence, ex['f0_confidence']), axis=-1)
  except StopIteration:
    raise ValueError(
        'TFRecord contains no examples. Please try re-running the pipeline with '
        'different audio file(s).')


# Plot and play each item in the batch.
for n in range(batch_size):
  local_utils.specplot(audio[n])
  local_utils.play(audio[n])

  f, ax = plt.subplots(3, 1, figsize=(14, 4))
  # Each 10-second example holds 2500 feature frames.
  x = np.linspace(0, 10.0*sequence_length, 2500*sequence_length)
  ax[0].set_ylabel('loudness_db')
  ax[0].plot(x, loudness[n])
  ax[1].set_ylabel('F0_Hz')
  ax[1].set_xlabel('seconds')
  ax[1].plot(x, f0[n])
  ax[2].set_ylabel('F0_confidence')
  ax[2].set_xlabel('seconds')
  ax[2].plot(x, f0_confidence[n])

## Train Model

We will now train a "solo instrument" model. This means the model is conditioned only on the fundamental frequency (f0) and loudness with no instrument ID or latent timbre feature. If you uploaded audio of multiple instruments, the neural network you train will attempt to model all timbres, but will likely associate certain timbres with different f0 and loudness conditions.

First, let's start up a [TensorBoard](https://www.tensorflow.org/tensorboard) to monitor our loss as training proceeds. 

Initially, TensorBoard will report `No dashboards are active for the current data set.`, but once training begins, the dashboards should appear.

In [None]:
# %reload_ext tensorboard
# import tensorboard as tb
# tb.notebook.start('--logdir "{}"'.format(DATA_DIR))

### We will now begin training. 

Note that we specify [gin configuration](https://github.com/google/gin-config) files for both the model architecture ([solo_instrument.gin](TODO)) and the dataset ([tfrecord.gin](TODO)), which are both predefined in the library. You could also create your own. We then override some of the specific params: `batch_size` (which is defined in the model gin file) and the tfrecord path (which is defined in the dataset file).
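Each override is a `name=value` binding whose value is written as a Python literal, so string values need their own quotes while numbers do not. The sketch below builds such bindings with a hypothetical helper, `gin_param`, which is not part of gin or DDSP. The file pattern is a placeholder path.

```python
import json


def gin_param(name, value):
  """Format one `name=value` binding (hypothetical helper, not a gin API)."""
  if isinstance(value, str):
    value = json.dumps(value)  # Wrap strings in double quotes, with escaping.
  return f'{name}={value}'


params = [
    # Placeholder path, for illustration only.
    gin_param('TFRecordProvider.file_pattern', '/tmp/train.tfrecord-train*'),
    gin_param('batch_size', 16),
]
for p in params:
  print(p)
```

The resulting strings can be passed one per `--gin_param` flag, or collected in a list for `gin.parse_config_files_and_bindings`.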

### Training Notes:
* Models typically perform well when the loss drops to the range of ~4.5-5.0.
* Depending on the dataset, this usually takes between 5k and 30k training steps.
* The default is set to 30k, but you can stop training at any time, and for timbre transfer, it's best to stop before the loss drops too far below ~5.0 to avoid overfitting.
* On the Colab GPU, this can take around 3-20 hours.
* We **highly recommend** saving checkpoints directly to your drive account, as Colab restarts after about 12 hours and you may lose all of your checkpoints.
* By default, checkpoints will be saved every 300 steps with a maximum of 10 checkpoints (at ~60MB/checkpoint this is ~600MB). Feel free to adjust these numbers depending on the frequency of saves you would like and space on your drive.
* If you're restarting a session and the save directory (`DATA_DIR` in this notebook) was previously used for training, training should resume from the last checkpoint.
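The checkpoint arithmetic in the notes above can be rechecked when you change the save settings. This is a small sketch with a hypothetical helper. The ~60 MB per checkpoint figure is the estimate quoted above, not a measured value.

```python
def checkpoint_budget(steps_per_save=300, checkpoints_to_keep=10,
                      mb_per_checkpoint=60):
  """Return (disk usage in MB, steps between oldest and newest checkpoint)."""
  disk_mb = checkpoints_to_keep * mb_per_checkpoint
  # With N checkpoints kept, the oldest is N - 1 saves behind the newest.
  rollback_steps = (checkpoints_to_keep - 1) * steps_per_save
  return disk_mb, rollback_steps


disk_mb, rollback_steps = checkpoint_budget()
print(f'~{disk_mb} MB on disk, can roll back up to {rollback_steps} steps')
```

Saving less often or keeping more checkpoints widens the rollback window, which helps if you want to return to a checkpoint from before the loss dropped too far.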

In [None]:
GIN_FILE_1 = PROJECT_DIR.joinpath('ddsp', 'ddsp', 'training', 'gin', 'models', 'predictor.gin')
GIN_FILE_2 = PROJECT_DIR.joinpath('ddsp', 'ddsp', 'training', 'gin', 'datasets', 'tfrecord_predictor.gin')

!ddsp_run \
  --mode=train \
  --alsologtostderr \
  --save_dir="$DATA_DIR" \
  --gin_file="$GIN_FILE_1" \
  --gin_file="$GIN_FILE_2" \
  --gin_param="TFRecordProvider.file_pattern=\"$TRAIN_TFRECORD_FILEPATTERN\"" \
  --gin_param="batch_size=16" \
  --gin_param="train_util.train.num_steps=3000000" \
  --gin_param="train_util.train.steps_per_save=300" \
  --gin_param="trainers.Trainer.checkpoints_to_keep=10"

In [None]:
!ddsp_run \
  --mode=eval \
  --alsologtostderr \
  --save_dir="$DATA_DIR" \
  --gin_file="$GIN_FILE_1" \
  --gin_file="$GIN_FILE_2" \
  --gin_param="TFRecordProvider.file_pattern=\"$TEST_TFRECORD_FILEPATTERN\"" \
  --run_once

In [None]:
from absl import logging
import gin
from ddsp.training import cloud
from ddsp.training import models
from ddsp.training import train_util
from ddsp.training import trainers
from ddsp.training.ddsp_run import get_gin_path

gfile = tf.io.gfile

GIN_PATH = get_gin_path()

GIN_FILE_1 = str(PROJECT_DIR.joinpath('ddsp', 'ddsp', 'training', 'gin', 'models', 'predictor.gin'))
GIN_FILE_2 = str(PROJECT_DIR.joinpath('ddsp', 'ddsp', 'training', 'gin', 'datasets', 'tfrecord_predictor.gin'))

restore_dir = os.path.expanduser('')
save_dir = os.path.expanduser(DATA_DIR)
# If no separate restore directory is given, use the save directory.
restore_dir = save_dir if not restore_dir else restore_dir
logging.info('Restore Dir: %s', restore_dir)
logging.info('Save Dir: %s', save_dir)

gfile.makedirs(restore_dir)  # Only makes dirs if they don't exist.

# Enable parsing gin files on Google Cloud.
gin.config.register_file_reader(tf.io.gfile.GFile, tf.io.gfile.exists)
# Add user folders to the gin search path.
for gin_search_path in [GIN_PATH] + []:
    gin.add_config_file_search_path(gin_search_path)

# Parse gin configs, later calls override earlier ones.
with gin.unlock_config():
    # Optimization defaults.
    use_tpu = bool('')
    opt_default = 'base.gin' if not use_tpu else 'base_tpu.gin'
    gin.parse_config_file(os.path.join('optimization', opt_default))
    eval_default = 'eval/basic.gin'
    gin.parse_config_file(eval_default)

    # Load operative_config if it exists (model has already trained).
    try:
        operative_config = train_util.get_latest_operative_config(restore_dir)
        logging.info('Using operative config: %s', operative_config)
        operative_config = cloud.make_file_paths_local(operative_config, GIN_PATH)
        gin.parse_config_file(operative_config, skip_unknown=True)
    except FileNotFoundError:
        logging.info('Operative config not found in %s', restore_dir)

    # User gin config and user hyperparameters from flags.
    gin_params = [
        'TFRecordProvider.file_pattern="' + str(TRAIN_TFRECORD_FILEPATTERN) + '"',
        'batch_size=16',
        'train_util.train.num_steps=3000000',
        'train_util.train.steps_per_save=300',
        'trainers.Trainer.checkpoints_to_keep=10'
    ]
    gin_file = cloud.make_file_paths_local([GIN_FILE_1, GIN_FILE_2], GIN_PATH)
    gin.parse_config_files_and_bindings(
        gin_file, gin_params, skip_unknown=True)

logging.info('Operative Gin Config:\n%s', gin.config.config_str())

# Training.
strategy = train_util.get_strategy('',
                                    cluster_config='')
with strategy.scope():
    model = models.get_model()
    trainer = trainers.get_trainer_class()(model, strategy)

train_util.train(data_provider=gin.REQUIRED,
                    trainer=trainer,
                    save_dir=save_dir,
                    restore_dir=restore_dir,
                    early_stop_loss_value=None,
                    report_loss_to_hypertune=None)

## Resynthesis

Check how well the model reconstructs the training data.

In [None]:
from ddsp.local.local_utils import play, specplot
import ddsp.training
import gin
from matplotlib import pyplot as plt
import numpy as np
from datetime import datetime
from ddsp.training.preprocessing import scale_f0_hz, scale_db

TEST_TFRECORD_FILEPATTERN = str(PROJECT_DIR.joinpath('TFRecords', 'chitarra_michele_rossi', 'separate', '9_les_neck_pick', 'train.tfrecord*'))

data_provider = ddsp.training.data.TFRecordProvider(str(TRAIN_TFRECORD_FILEPATTERN), example_secs=10)
dataset = data_provider.get_batch(batch_size=30, shuffle=True)

try:
  batch = next(iter(dataset))
except StopIteration:  # Raised by next() when the dataset is empty.
  raise ValueError(
      'TFRecord contains no examples. Please try re-running the pipeline with '
      'different audio file(s).')

# Parse the gin config.
# gin_file = DATA_DIR.joinpath('operative_config-0.gin')
gin_file = str(PROJECT_DIR.joinpath('ddsp', 'ddsp', 'training', 'gin', 'models', 'predictor.gin'))
gin.parse_config_file(gin_file)

# Load model
model = ddsp.training.models.Predictor()
model.restore(DATA_DIR)

# print(batch)

# Resynthesize audio.
before = datetime.now()
outputs = model(batch, training=False)
# print(outputs)
after = datetime.now()
# audio_gen = model.get_audio_from_outputs(outputs)
# audio = batch['audio']

time = after - before
print(time)

%matplotlib widget

f, ax = plt.subplots(2, 1, figsize=(14, 4))
# Examples are 10 seconds long (example_secs=10), so the time axis spans 0-10 s.
x = np.linspace(0, 10.0, 2499)
ax[0].set_ylabel('loudness_db')
ax[0].plot(x, scale_db(batch['loudness_db'][0, :-1]))
ax[0].plot(x, outputs['ld_scaled'][0, 1:])
ax[1].set_ylabel('F0_Hz')
ax[1].set_xlabel('seconds')
ax[1].plot(x, np.squeeze(scale_f0_hz(batch['f0_hz'][0, :-1])))
ax[1].plot(x, np.squeeze(outputs['f0_scaled'][0, 1:]))
# Legend: the first line plotted is the original, the second is the prediction.
ax[0].legend(['original', 'prediction'], loc='upper right')

# print('Original Audio')
# specplot(audio)
# play(audio)

# print('Resynthesis')
# specplot(audio_gen)
# play(audio_gen)