# Mapping model training

## Setup Google Drive

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the model checkpoints are stored under
# this mount point (see drive_dir in the directory-setup cell).
drive.mount('/content/drive')

## Install Dependencies

First we install the required dependencies with `pip`.

In [90]:
# Colab magic: select the TensorFlow 2.x runtime before TensorFlow is imported.
%tensorflow_version 2.x
# Quietly (-q) install/upgrade (-U) the pinned ddsp release together with its
# data_preparation extras.
!pip install -qU ddsp[data_preparation]==1.0.1

## Make directories to save model and data

In [None]:
import os

drive_dir = '/content/drive/My Drive/nsynth_guitar'
checkpoint_dir = os.path.join(drive_dir, 'mapping/checkpoint')

# Fail early with an actionable message when Drive is not mounted or the
# project folder is missing; a bare assert gives no hint about the cause.
if not os.path.isdir(drive_dir):
  raise FileNotFoundError(
      'Drive directory not found: {!r}. Mount Google Drive and create the '
      'folder before running this notebook.'.format(drive_dir))
print('Drive Directory Exists:', drive_dir)

# Create the checkpoint folder from Python instead of shelling out, so the
# space in "My Drive" never depends on shell quoting. A no-op if it exists.
os.makedirs(checkpoint_dir, exist_ok=True)

## Clear existing checkpoints

In [92]:
# Uncomment the lines below to delete all saved checkpoints and train from scratch.
# import shutil

# try:
#     shutil.rmtree(checkpoint_dir)
# except OSError as e:
#     print("Error: %s : %s" % (checkpoint_dir, e.strerror))

## Download Complete NSynth Guitar Subset

In [93]:
dataset_dir = '/content/complete'
train_dataset_dir = os.path.join(dataset_dir, 'train')
valid_dataset_dir = os.path.join(dataset_dir, 'valid')
test_dataset_dir = os.path.join(dataset_dir, 'test')

train_tfrecord_file = os.path.join(train_dataset_dir, 'complete.tfrecord')
valid_tfrecord_file = os.path.join(valid_dataset_dir, 'complete.tfrecord')
test_tfrecord_file = os.path.join(test_dataset_dir, 'complete.tfrecord')

if not os.path.exists(dataset_dir):
  train = 'https://osr-tsoai.s3.amazonaws.com/complete/train/complete.tfrecord'
  valid = 'https://osr-tsoai.s3.amazonaws.com/complete/valid/complete.tfrecord'
  test = 'https://osr-tsoai.s3.amazonaws.com/complete/test/complete.tfrecord'

  print("Downloading train dataset to {}\n".format(train_dataset_dir))
  !mkdir -p "$train_dataset_dir"
  !curl $train --output $train_tfrecord_file

  print("\nDownloading valid dataset to {}\n".format(valid_dataset_dir))
  !mkdir -p "$valid_dataset_dir"
  !curl $valid --output $valid_tfrecord_file

  print("\nDownloading test dataset to {}\n".format(test_dataset_dir))
  !mkdir -p "$test_dataset_dir"
  !curl $test --output $test_tfrecord_file

## Define DataProvider class

In [94]:
import tensorflow as tf
import ddsp.training.data as data

class CompleteTFRecordProvider(data.RecordProvider):
  """Record provider that reads the "complete" TFRecord files.

  Every serialized example is parsed according to `features_dict`. When a
  `map_func` was supplied it is applied to the parsed feature dict and its
  result becomes the dataset element; otherwise the parsed dict is returned.
  """

  def __init__(self,
               file_pattern=None,
               example_secs=4,
               sample_rate=16000,
               frame_rate=250,
               map_func=None):
    """Sets up the provider.

    Args:
      file_pattern: Pattern of the TFRecord files to read (forwarded to the
        base class).
      example_secs: Forwarded to the base class.
      sample_rate: Forwarded to the base class.
      frame_rate: Forwarded to the base class.
      map_func: Optional callable applied to each parsed feature dict.
    """
    # NOTE(review): the base class presumably derives self._audio_length and
    # self._feature_length from these arguments -- confirm against ddsp.
    super().__init__(file_pattern,
                     example_secs,
                     sample_rate,
                     frame_rate,
                     tf.data.TFRecordDataset)
    self._map_func = map_func

  def get_dataset(self, shuffle=True):
    """Returns a dataset of parsed (and optionally mapped) examples.

    Args:
      shuffle: Forwarded to `tf.data.Dataset.list_files` for the file names.

    Returns:
      A `tf.data.Dataset` with one element per record.
    """
    autotune = tf.data.experimental.AUTOTUNE

    def decode(serialized):
      parsed = tf.io.parse_single_example(serialized, self.features_dict)
      if self._map_func is None:
        return parsed
      return self._map_func(parsed)

    file_names = tf.data.Dataset.list_files(self._file_pattern,
                                            shuffle=shuffle)
    records = file_names.interleave(
        map_func=self._data_format_map_fn,
        cycle_length=40,
        num_parallel_calls=autotune)
    return records.map(decode, num_parallel_calls=autotune)

  @property
  def features_dict(self):
    """Parsing spec mapping each stored feature name to a FixedLenFeature."""
    frames = self._feature_length

    def ints(size):
      return tf.io.FixedLenFeature([size], dtype=tf.int64)

    def floats(size):
      return tf.io.FixedLenFeature([size], dtype=tf.float32)

    return {
        # Fixed-size per-example fields.
        'sample_name': tf.io.FixedLenFeature([1], dtype=tf.string),
        'note_number': ints(1),
        'velocity': ints(1),
        'instrument_source': ints(1),
        'qualities': ints(10),
        # Audio, self._audio_length values long.
        'audio': floats(self._audio_length),
        # One value per feature frame.
        'f0_hz': floats(frames),
        'f0_confidence': floats(frames),
        'loudness_db': floats(frames),
        'f0_scaled': floats(frames),
        'ld_scaled': floats(frames),
        # Stored flat, 16 values per feature frame.
        'z': floats(frames * 16),
    }

## Define features map function

In [95]:
def features_map(features):
  """Converts one parsed example into an (inputs, targets) sequence pair.

  The per-example fields (note number, velocity, instrument source,
  qualities) are repeated on every frame and concatenated with the per-frame
  `z` values to form the inputs. The targets are `f0_scaled` and `ld_scaled`
  stacked as two channels.

  Args:
    features: Dict of tensors with the keys 'note_number', 'velocity',
      'instrument_source', 'qualities', 'f0_scaled', 'ld_scaled' and 'z'.

  Returns:
    Tuple `(inputs, targets)`; both have one row per frame of `f0_scaled`,
    and `targets` has 2 channels.
  """
  num_frames = features['f0_scaled'].shape[0]

  def tile_over_frames(values):
    """Repeats a 1-D tensor on every frame and casts the result to float32."""
    width = values.shape[0]
    tiled = tf.broadcast_to(tf.expand_dims(values, axis=0),
                            shape=(num_frames, width))
    return tf.cast(tiled, dtype=tf.float32)

  # Normalization: note_number and velocity span 0-127; instrument_source is
  # 0 (acoustic), 1 (electronic) or 2 (synthetic). qualities is used as-is.
  per_note_channels = [
      tile_over_frames(features['note_number'] / 127),
      tile_over_frames(features['velocity'] / 127),
      tile_over_frames(features['instrument_source'] / 2),
      tile_over_frames(features['qualities']),
  ]

  # z is stored flat; restore its (frames, 16) layout.
  latent = tf.reshape(features['z'], shape=(num_frames, 16))

  inputs = tf.concat(per_note_channels + [latent], axis=-1)

  targets = tf.concat(
      [tf.expand_dims(features[key], axis=-1)
       for key in ('f0_scaled', 'ld_scaled')],
      axis=-1)

  return (inputs, targets)

## Create datasets

In [96]:
batch_size = 16
example_secs = 4
sample_rate = 16000
frame_rate = 250


def _make_split(tfrecord_file):
  """Builds the data provider and batched dataset for one dataset split.

  Args:
    tfrecord_file: Path of the split's TFRecord file; a trailing '*' is
      appended to form the file pattern.

  Returns:
    Tuple `(provider, dataset)` where `dataset` comes from
    `provider.get_batch(batch_size, shuffle=True, repeats=-1)`.
  """
  provider = CompleteTFRecordProvider(
      file_pattern=tfrecord_file + '*',
      example_secs=example_secs,
      sample_rate=sample_rate,
      frame_rate=frame_rate,
      map_func=features_map)
  dataset = provider.get_batch(batch_size, shuffle=True, repeats=-1)
  return provider, dataset


# Each split must read its own file. Previously the valid and test providers
# were also pointed at train_tfrecord_file, so validation and evaluation were
# silently run on the training data.
train_data_provider, train_dataset = _make_split(train_tfrecord_file)
valid_data_provider, valid_dataset = _make_split(valid_tfrecord_file)
test_data_provider, test_dataset = _make_split(test_tfrecord_file)

## Create and compile mapping model

In [97]:
# Sequence-to-sequence model: a GRU that emits one state per frame
# (return_sequences=True), followed by a per-frame projection to 2 outputs.
model = tf.keras.models.Sequential([
    tf.keras.layers.GRU(32, return_sequences=True),
    tf.keras.layers.Dense(2, activation='tanh')
])

loss = tf.keras.losses.MeanAbsoluteError()

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=loss,
    # Metrics belong to tf.keras.metrics; the original passed the
    # tf.keras.losses class here, which is a Loss object, not a Metric.
    metrics=[tf.keras.metrics.MeanSquaredError()])

## Build model

In [None]:
# Run one training batch through the model so its layers are built for the
# batch's input shape before the summary is requested.
x_train, y_train = next(iter(train_dataset))
out = model(x_train)

# summary() prints the table itself and returns None, so wrapping it in
# print() emitted a stray "None" line after the table.
model.summary()

## Load model checkpoint

In [99]:
checkpoint_file = os.path.join(checkpoint_dir, 'cp.ckpt')

# Resume only when this exact checkpoint exists. A weights checkpoint saved in
# TensorFlow format under the prefix "cp.ckpt" is written as "cp.ckpt.index"
# plus data shards, so the index file is the marker to test for. The previous
# "directory is not empty" check also fired on unrelated files, which made
# load_weights raise.
if os.path.exists(checkpoint_file + '.index'):
    model.load_weights(checkpoint_file)

## Create training callbacks

In [100]:
# Save the weights (not the full model) to checkpoint_file once per epoch.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_file,
    save_freq='epoch',
    verbose=0,
    save_weights_only=True)

# Watch the validation loss with a patience of 5 epochs.
early_stop = tf.keras.callbacks.EarlyStopping(
    patience=5,
    monitor='val_loss')

def scheduler(epoch, lr):
  """Learning-rate schedule: hold for epochs 0-9, then shrink by 10% per call.

  Args:
    epoch: Index of the epoch the rate is requested for.
    lr: Current learning rate.

  Returns:
    `lr` unchanged while `epoch < 10`, otherwise `lr * 0.9`.
  """
  still_holding = epoch < 10
  return lr if still_holding else lr * 0.9

# Wrap scheduler() in a Keras callback so it can be passed to model.fit.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)

## Train the model

In [None]:
# Training budget: up to 20 epochs of 100 training batches each, with
# 10 validation batches evaluated per epoch.
epochs = 20
steps_per_epoch = 100
validation_steps = 10

model.fit(
    x=train_dataset,
    validation_data=valid_dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=[checkpoint, early_stop, lr_scheduler])

## Evaluate model on test dataset

In [None]:
# Evaluate on 500 batches drawn from test_dataset.
model.evaluate(test_dataset, steps=500)