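References:

- pix2pix (TensorFlow 2.0 alpha tutorial): https://www.tensorflow.org/alpha/tutorials/generative/pix2pix (now at https://www.tensorflow.org/tutorials/generative/pix2pix)

- Custom training loops with `tf.distribute` (TensorFlow 2.0 alpha tutorial): https://www.tensorflow.org/alpha/tutorials/distribute/training_loops (now at https://www.tensorflow.org/tutorials/distribute/custom_training)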

In [1]:
# Colab setup: TF 2.0 nightly (GPU), the MUSDB18 loader/evaluator, and ffmpeg
# for audio decoding. (An earlier revision pinned tensorflow-gpu==2.0.0-alpha0.)
# NOTE(review): tf-nightly is a moving target; pin a dated build for reproducibility.
!pip install -q tf-nightly-gpu-2.0-preview
!pip install -q musdb museval
# -y: without it apt asks for confirmation, which aborts in a non-interactive
# cell whenever ffmpeg actually has to be installed.
!apt-get install -y ffmpeg

from pathlib import Path
import time

import tensorflow as tf
from tensorflow.python.ops import nn
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.ops import array_ops
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.eager import context
import musdb
import museval
import numpy as np
%load_ext tensorboard

[K     |████████████████████████████████| 349.2MB 63kB/s 
[K     |████████████████████████████████| 61kB 27.1MB/s 
[K     |████████████████████████████████| 430kB 56.0MB/s 
[K     |████████████████████████████████| 3.1MB 32.1MB/s 
[?25h  Building wheel for wrapt (setup.py) ... [?25l[?25hdone
[31mERROR: thinc 6.12.1 has requirement wrapt<1.11.0,>=1.10.0, but you'll have wrapt 1.11.1 which is incompatible.[0m
[K     |████████████████████████████████| 512kB 8.4MB/s 
[K     |████████████████████████████████| 81kB 31.4MB/s 
[?25h  Building wheel for simplejson (setup.py) ... [?25l[?25hdone
Reading package lists... Done
Building dependency tree       
Reading state information... Done
ffmpeg is already the newest version (7:3.4.6-0ubuntu0.18.04.1).
The following package was automatically installed and is no longer required:
  libnvidia-common-410
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 6 not upgraded.


In [2]:
# Runtime sanity checks.
# ENABLE_CONTROL_FLOW_V2 is a private TensorFlow switch, flipped here before any
# model is built.
from tensorflow.python.ops import control_flow_util
from packaging import version

control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

print("TensorFlow version: ", tf.__version__)
# Only the major release number matters: anything below 2 is rejected.
assert version.parse(tf.__version__).release[0] >= 2, (
    "This notebook requires TensorFlow 2.0 or above.")

print("Executing eagerly : {}".format(tf.executing_eagerly()))

TensorFlow version:  2.0.0-dev20190527
Executing eagerly : True


In [3]:
# Mount Google Drive and mirror its musdb18 folder into the local VM at
# /content/musdb18 (the checkpoint and log directories below live there).
from google.colab import drive
drive.mount('/content/gdrive')
!mkdir -p "/content/gdrive/My Drive/musdb18" /content/musdb18
# Copy the folder's *contents* ("src/."): a plain `cp -a src /content/musdb18`
# nests the copy as /content/musdb18/musdb18 once the destination exists,
# e.g. when this cell is re-run.
!cp -a "/content/gdrive/My Drive/musdb18/." /content/musdb18/

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
# Unpack the MUSDB18 archive from Drive (-n: never overwrite existing files, so
# re-running is cheap), then load the training subset of tracks.
!unzip -q -n "/content/gdrive/My Drive/musdb18.zip" -d sample_data/musdb18

mus = musdb.DB(root_dir='sample_data/musdb18')
tracks = mus.load_mus_tracks(subsets=['train'])

In [5]:
# Distribution strategy: MirroredStrategy replicates the model over the visible
# local devices. The TPU variant below is on hold until the Cloud TPU 1.14
# release (https://github.com/tensorflow/tensorflow/issues/24412):
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   strategy = tf.distribute.experimental.TPUStrategy(resolver)
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(
    strategy.num_replicas_in_sync))

Number of devices: 1


In [6]:
# Dataset statistics. Recomputing them needs every track's decoded audio, so the
# values are hard-coded (presumably from an earlier run); the expressions that
# produce them are kept in the trailing comments.
min_track_length = 569344  # min(tracks, key=lambda t: t.audio.shape[0]).audio.shape[0]
max_track_length = 27711488  # max(tracks, key=lambda t: t.audio.shape[0]).audio.shape[0]
print("min track length = {}".format(min_track_length))
print("max track length = {}".format(max_track_length))
print("min value = {}".format(-1.))  # min(t.audio.min() for t in tracks)))
print("max value = {}".format(1.))  # max(t.audio.max() for t in tracks)))

max track length = 569344
max track length = 27711488
min value = -1.0
max value = 1.0


In [7]:
# Training configuration.
# INPUT_SIZE: a power of two (so the generator's 8 stride-2 down/up-sampling
# stages line up with their skip connections), never longer than the shortest track.
INPUT_SIZE = min(32768, 2**int(np.log2(min_track_length)))
BATCH_SIZE = 16 * strategy.num_replicas_in_sync  # global batch size; use 128 on TPU
AUTOTUNE = tf.data.experimental.AUTOTUNE
OUTPUT_CHANNELS = tracks[0].audio.shape[1]  # audio channels of the tracks
LAMBDA = 100  # weight of the L1 term in the generator loss
EPOCHS = 500
SAVE_FREQ = 5  # checkpoint every SAVE_FREQ epochs
LOG_FREQ = 10  # write TensorBoard summaries every LOG_FREQ epochs
VALIDATION_SPLIT = 0.2
# n_validation needs to be at least BATCH_SIZE or no validation will take place
n_validation = max(int(len(tracks) * VALIDATION_SPLIT), BATCH_SIZE)
n_train = len(tracks) - n_validation
# NOTE(review): these step counts are derived from *track* counts, but the datasets
# yield INPUT_SIZE-sample chunks (many per track), so one "epoch" only consumes
# about n_train chunks rather than n_train whole tracks.
validation_steps_per_epoch = n_validation // BATCH_SIZE
train_steps_per_epoch = n_train // BATCH_SIZE

# Fail fast rather than silently running epochs with zero training steps, or
# building a generator whose skip connections cannot line up.
assert train_steps_per_epoch > 0, (
    "Not enough training tracks ({}) for one batch of {} after holding out {} "
    "for validation".format(n_train, BATCH_SIZE, n_validation))
assert INPUT_SIZE % 2**8 == 0, "INPUT_SIZE must be divisible by 2**8"

print("INPUT_SIZE : {}".format(INPUT_SIZE))
print("BATCH_SIZE : {}".format(BATCH_SIZE))
print("OUTPUT_CHANNELS : {}".format(OUTPUT_CHANNELS))
print("Validation set size : {}".format(n_validation))
print("Train set size : {}".format(n_train))

INPUT_SIZE : 32768
BATCH_SIZE : 16
OUTPUT_CHANNELS : 2
Validation set size : 20
Train set size : 80


In [0]:
# Hold out the first n_validation tracks for validation and train on the rest.
validation_tracks, train_tracks = tracks[:n_validation], tracks[n_validation:]

In [0]:
def _preprocess_conv1d_input(x, data_format):
  """Transpose and cast the input before the conv1d.
  Arguments:
      x: input tensor.
      data_format: string, `"channels_last"` or `"channels_first"`.
  Returns:
      A tensor.
  """
  tf_data_format = 'NWC'  # to pass TF Conv2dNative operations
  if data_format == 'channels_first':
    if not _has_nchw_support():
      x = array_ops.transpose(x, (0, 2, 1))  # NCW -> NWC
    else:
      tf_data_format = 'NCW'
  return x, tf_data_format

def _preprocess_padding(padding):
  """Convert keras' padding to TensorFlow's padding.
  Arguments:
      padding: string, one of 'same' , 'valid'
  Returns:
      a string, one of 'SAME', 'VALID'.
  Raises:
      ValueError: if invalid `padding'`
  """
  if padding == 'same':
    padding = 'SAME'
  elif padding == 'valid':
    padding = 'VALID'
  else:
    raise ValueError('Invalid padding: ' + str(padding))
  return padding

In [0]:
# ref : tf.keras.backend.conv2d_transpose
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/keras/backend.py

def conv1d_transpose(x,
                     kernel,
                     output_shape,
                     strides=(1,),
                     padding='valid',
                     data_format=None):
  """1D deconvolution (i.e. transposed convolution).

  Arguments:
      x: Tensor or variable.
      kernel: kernel tensor, `(width, output_channels, input_channels)`.
      output_shape: 1D int tensor, or tuple/list, of length 3 giving the output
          shape laid out according to `data_format`. In a tuple/list the batch
          entry may be `None`; it is then taken from `x`.
      strides: an integer, or a tuple/list of one integer.
      padding: string, `"same"` or `"valid"`.
      data_format: string, `"channels_last"` or `"channels_first"`.
          Defaults to `tf.keras.backend.image_data_format()`.
  Returns:
      A tensor, result of transposed 1D convolution.
  Raises:
      ValueError: if `data_format` is neither `channels_last` or
      `channels_first`.
  """
  if data_format is None:
    # The bare `image_data_format()` of the Keras backend source is not in
    # scope here; use the public alias.
    data_format = tf.keras.backend.image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))
  # Accept a bare int (or a list) so the tuple concatenation below works.
  if not isinstance(strides, (tuple, list)):
    strides = (strides,)
  strides = tuple(strides)

  if isinstance(output_shape, (tuple, list)):
    output_shape = list(output_shape)
    # Resolve an unknown batch size *before* stacking: array_ops.stack cannot
    # pack None, and a stacked tensor's entries are never `None`.
    if output_shape[0] is None:
      output_shape[0] = array_ops.shape(x)[0]
    output_shape = array_ops.stack(output_shape)

  x, tf_data_format = _preprocess_conv1d_input(x, data_format)

  if data_format == 'channels_first' and tf_data_format == 'NWC':
    # The op runs in NWC, so permute the requested NCW output shape to match.
    output_shape = array_ops.stack(
        [output_shape[0], output_shape[2], output_shape[1]])

  padding = _preprocess_padding(padding)
  if tf_data_format == 'NWC':
    strides = (1,) + strides + (1,)
  else:
    strides = (1, 1) + strides
  x = nn.conv1d_transpose(x,
                          kernel,
                          output_shape,
                          list(strides),
                          padding=padding,
                          data_format=tf_data_format)
  if data_format == 'channels_first' and tf_data_format == 'NWC':
    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW
  return x

In [0]:
# ref : tf.keras.layers.Conv2DTranspose
# https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/keras/layers/convolutional.py

class Conv1DTranspose(tf.keras.layers.Conv1D):
  """Transposed convolution layer (sometimes called Deconvolution).
  The need for transposed convolutions generally arises
  from the desire to use a transformation going in the opposite direction
  of a normal convolution, i.e., from something that has the shape of the
  output of some convolution to something that has the shape of its input
  while maintaining a connectivity pattern that is compatible with
  said convolution.
  When using this layer as the first layer in a model,
  provide the keyword argument `input_shape`
  (tuple of integers, does not include the sample axis),
  e.g. `input_shape=(128, 2)` for 128 units long stereo sound
  in `data_format="channels_last"`.
  Arguments:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer specifying the width of the 1D
      convolution window.
    strides: An integer specifying the strides of the
      convolution along the width.
    padding: one of `"valid"` or `"same"` (case-insensitive).
    output_padding: An integer specifying the amount of padding along
      the width of the output tensor.
      The amount of output padding along a given dimension must be
      lower than the stride along that same dimension.
      If set to `None` (default), the output shape is inferred.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, width, channels)` while `channels_first`
      corresponds to inputs with shape
      `(batch, channels, width)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: an integer or tuple of one integer. Only `1` is
      supported: the underlying `conv1d_transpose` op has no dilation, so
      any other value raises `ValueError` instead of being silently ignored.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to the kernel matrix.
    bias_constraint: Constraint function applied to the bias vector.
  Input shape:
    3D tensor with shape:
    `(batch, channels, cols)` if data_format='channels_first'
    or 3D tensor with shape:
    `(batch, cols, channels)` if data_format='channels_last'.
  Output shape:
    3D tensor with shape:
    `(batch, filters, new_cols)` if data_format='channels_first'
    or 3D tensor with shape:
    `(batch, new_cols, filters)` if data_format='channels_last'.
    `cols` value might have changed due to padding.
  References:
    - [A guide to convolution arithmetic for deep
      learning](https://arxiv.org/abs/1603.07285v1)
    - [Deconvolutional
      Networks](https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf)
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               output_padding=None,
               data_format=None,
               dilation_rate=(1,),
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super().__init__(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
        activation=tf.keras.activations.get(activation),
        use_bias=use_bias,
        kernel_initializer=tf.keras.initializers.get(kernel_initializer),
        bias_initializer=tf.keras.initializers.get(bias_initializer),
        kernel_regularizer=tf.keras.regularizers.get(kernel_regularizer),
        bias_regularizer=tf.keras.regularizers.get(bias_regularizer),
        activity_regularizer=tf.keras.regularizers.get(activity_regularizer),
        kernel_constraint=tf.keras.constraints.get(kernel_constraint),
        bias_constraint=tf.keras.constraints.get(bias_constraint),
        **kwargs)

    # `call` goes through `conv1d_transpose`, which takes no dilation argument.
    # Reject dilation instead of accepting it and silently ignoring it.
    if tuple(self.dilation_rate) != (1,):
      raise ValueError('Conv1DTranspose does not support dilation_rate != 1. '
                       'Received: ' + str(self.dilation_rate))

    self.output_padding = output_padding
    if self.output_padding is not None:
      self.output_padding = conv_utils.normalize_tuple(
          self.output_padding, 1, 'output_padding')
      for stride, out_pad in zip(self.strides, self.output_padding):
        if out_pad >= stride:
          raise ValueError('Stride ' + str(self.strides) + ' must be '
                           'greater than output padding ' +
                           str(self.output_padding))

  def build(self, input_shape):
    """Create `kernel` (width, filters, input_channels) and optional `bias`."""
    input_shape = tensor_shape.TensorShape(input_shape)
    if len(input_shape) != 3:
      raise ValueError('Inputs should have rank 3. Received input shape: ' +
                       str(input_shape))
    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = -1
    # dimension_value works whether the TensorShape holds Dimension objects or
    # plain ints, unlike `.dims[axis].value`.
    input_dim = tensor_shape.dimension_value(input_shape[channel_axis])
    if input_dim is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found None: ' + str(input_shape))
    input_dim = int(input_dim)
    self.input_spec = InputSpec(ndim=3, axes={channel_axis: input_dim})
    # Transposed-conv kernels put the *output* channels before the input ones.
    kernel_shape = self.kernel_size + (self.filters, input_dim)

    self.kernel = self.add_weight(
        name='kernel',
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        trainable=True,
        dtype=self.dtype)
    if self.use_bias:
      self.bias = self.add_weight(
          name='bias',
          shape=(self.filters,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    """Apply the transposed convolution, then bias and activation."""
    inputs_shape = array_ops.shape(inputs)
    batch_size = inputs_shape[0]
    if self.data_format == 'channels_first':
      w_axis = 2
    else:
      w_axis = 1

    width = inputs_shape[w_axis]
    kernel_w = self.kernel_size[0]
    stride_w = self.strides[0]

    if self.output_padding is None:
      out_pad_w = None
    else:
      out_pad_w = self.output_padding[0]

    # Infer the dynamic output shape:
    out_width = conv_utils.deconv_output_length(width,
                                                kernel_w,
                                                padding=self.padding,
                                                output_padding=out_pad_w,
                                                stride=stride_w)
    if self.data_format == 'channels_first':
      output_shape = (batch_size, self.filters, out_width)
    else:
      output_shape = (batch_size, out_width, self.filters)

    output_shape_tensor = array_ops.stack(output_shape)
    outputs = conv1d_transpose(
        inputs,
        self.kernel,
        output_shape_tensor,
        strides=self.strides,
        padding=self.padding,
        data_format=self.data_format)

    if not context.executing_eagerly():
      # Infer the static output shape:
      out_shape = self.compute_output_shape(inputs.shape)
      outputs.set_shape(out_shape)

    if self.use_bias:
      if self.data_format == 'channels_first':
        # Broadcast-add along the channel axis (1). This avoids depending on
        # nn.bias_add accepting an 'NCW' data format.
        outputs += array_ops.reshape(self.bias, (1, self.filters, 1))
      else:
        outputs = nn.bias_add(
            outputs,
            self.bias,
            data_format=conv_utils.convert_data_format(self.data_format, ndim=3))

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

  def compute_output_shape(self, input_shape):
    """Static output shape: channels become `filters`; width is deconvolved."""
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    output_shape = list(input_shape)
    if self.data_format == 'channels_first':
      c_axis, w_axis = 1, 2
    else:
      c_axis, w_axis = 2, 1

    kernel_w = self.kernel_size[0]
    stride_w = self.strides[0]

    if self.output_padding is None:
      out_pad_w = None
    else:
      out_pad_w = self.output_padding[0]

    output_shape[c_axis] = self.filters
    output_shape[w_axis] = conv_utils.deconv_output_length(
        output_shape[w_axis],
        kernel_w,
        padding=self.padding,
        output_padding=out_pad_w,
        stride=stride_w)
    return tensor_shape.TensorShape(output_shape)

  def get_config(self):
    """Conv1D config plus `output_padding`. `dilation_rate` is dropped (always 1)."""
    config = super().get_config()
    config['output_padding'] = self.output_padding
    config.pop('dilation_rate')
    return config

In [0]:
def downsample(filters, size, apply_batchnorm=True):
  """Build one encoder stage: Conv1D(stride 2) -> [BatchNorm] -> LeakyReLU.

  The 'same'-padded stride-2 convolution halves the time axis.

  Args:
    filters: number of output channels.
    size: convolution kernel width.
    apply_batchnorm: insert BatchNormalization after the convolution.
  Returns:
    A `tf.keras.Sequential` holding the stage's layers.
  """
  stage = [
      tf.keras.layers.Conv1D(filters, size, strides=2, padding='same',
                             kernel_initializer=tf.random_normal_initializer(0., 0.02),
                             use_bias=False),
  ]
  if apply_batchnorm:
    stage.append(tf.keras.layers.BatchNormalization())
  stage.append(tf.keras.layers.LeakyReLU())
  return tf.keras.Sequential(stage)

In [0]:
def upsample(filters, size, apply_dropout=False):
  """Build one decoder stage: Conv1DTranspose(stride 2) -> BatchNorm -> [Dropout] -> ReLU.

  The 'same'-padded stride-2 transposed convolution doubles the time axis.

  Args:
    filters: number of output channels.
    size: kernel width of the transposed convolution.
    apply_dropout: insert Dropout(0.5) after the batch normalization.
  Returns:
    A `tf.keras.Sequential` holding the stage's layers.
  """
  stage = [
      Conv1DTranspose(filters, size, strides=2,
                      padding='same',
                      kernel_initializer=tf.random_normal_initializer(0., 0.02),
                      use_bias=False),
      tf.keras.layers.BatchNormalization(),
  ]
  if apply_dropout:
    stage.append(tf.keras.layers.Dropout(0.5))
  stage.append(tf.keras.layers.ReLU())
  return tf.keras.Sequential(stage)

In [14]:
# Smoke test on one full track: a stride-2 down/up-sampling round trip should
# halve and then restore the time axis (see the printed shapes).
inp = tf.expand_dims(tracks[0].audio, 0)  # prepend a batch axis of size 1
print(inp.shape)

down_model = downsample(2, (4,))
down_result = down_model(inp)
print(down_result.shape)

up_model = upsample(2, (4,))
up_result = up_model(down_result)
print(up_result.shape)

(1, 7552000, 2)
(1, 3776000, 2)
(1, 7552000, 2)


In [0]:
def Generator():
  """Build the 1D U-Net generator (pix2pix-style) used for source separation.

  Maps a mixture waveform `(batch, W, OUTPUT_CHANNELS)` to an estimate of the
  target with the same shape, squashed to [-1, 1] by `tanh`. W must be
  divisible by 2**8: eight stride-2 downsampling stages are mirrored by
  upsampling stages whose outputs are concatenated with the matching skips.

  Shapes in the comments are for input width W (W = INPUT_SIZE = 32768 here).
  """
  down_stack = [
    downsample(64, 4, apply_batchnorm=False),  # (bs, W/2, 64)
    downsample(128, 4),  # (bs, W/4, 128)
    downsample(256, 4),  # (bs, W/8, 256)
    downsample(512, 4),  # (bs, W/16, 512)
    downsample(512, 4),  # (bs, W/32, 512)
    downsample(512, 4),  # (bs, W/64, 512)
    downsample(512, 4),  # (bs, W/128, 512)
    downsample(512, 4),  # (bs, W/256, 512)
  ]

  up_stack = [  # shapes after concatenating the matching skip connection
    upsample(512, 4, apply_dropout=True),  # (bs, W/128, 1024)
    upsample(512, 4, apply_dropout=True),  # (bs, W/64, 1024)
    upsample(512, 4, apply_dropout=True),  # (bs, W/32, 1024)
    upsample(512, 4),  # (bs, W/16, 1024)
    upsample(256, 4),  # (bs, W/8, 512)
    upsample(128, 4),  # (bs, W/4, 256)
    upsample(64, 4),  # (bs, W/2, 128)
  ]

  # mean 0, stddev 0.02, as in downsample/upsample. The first positional
  # argument is the *mean*: `tf.random_normal_initializer(0.02)` would draw
  # around 0.02 with the default stddev.
  initializer = tf.random_normal_initializer(0., 0.02)
  last = Conv1DTranspose(OUTPUT_CHANNELS, 4,
                         strides=2,
                         padding='same',
                         kernel_initializer=initializer,
                         activation='tanh')  # (bs, W, OUTPUT_CHANNELS)

  concat = tf.keras.layers.Concatenate()

  inputs = tf.keras.layers.Input(shape=[None, OUTPUT_CHANNELS])
  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  # The bottleneck output is not a skip; pair the rest deepest-first with up_stack.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = concat([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)


In [0]:
def Discriminator():
  """Build the patch-level 1D discriminator.

  Takes `[input_image, target_image]` (each `(batch, W, OUTPUT_CHANNELS)`,
  i.e. the mixture and a real or generated source) and returns one score per
  output position, `(batch, W/8 - 2, 1)`. No activation is applied; the loss
  treats these scores as logits.

  Shapes in the comments are for input width W (W = INPUT_SIZE = 32768 here).
  """
  # mean 0, stddev 0.02. The first positional argument is the *mean*, so the
  # previous `tf.random_normal_initializer(0.02)` drew around 0.02 with the
  # default stddev.
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[None, OUTPUT_CHANNELS], name='input_image')
  tar = tf.keras.layers.Input(shape=[None, OUTPUT_CHANNELS], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar])  # (bs, W, 2 * OUTPUT_CHANNELS)

  down1 = downsample(64, 4, False)(x)  # (bs, W/2, 64)
  down2 = downsample(128, 4)(down1)  # (bs, W/4, 128)
  down3 = downsample(256, 4)(down2)  # (bs, W/8, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding1D()(down3)  # (bs, W/8 + 2, 256)
  conv = tf.keras.layers.Conv1D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1)  # (bs, W/8 - 1, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding1D()(leaky_relu)  # (bs, W/8 + 1, 512)

  last = tf.keras.layers.Conv1D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2)  # (bs, W/8 - 2, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)

In [17]:
# MODEL OBJECTS
# Everything that owns variables is created under the strategy scope, so the
# strategy decides where those variables live.
with strategy.scope():
  generator = Generator()
  discriminator = Discriminator()
  # Both players use identical Adam settings (lr 2e-4, beta_1 0.5).
  generator_optimizer, discriminator_optimizer = [
      tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(2)]

W0527 16:41:58.270230 140503906137984 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py:617: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [0]:
with strategy.scope():
  # SUM reduction: the loss functions divide by the global BATCH_SIZE themselves
  # (the per-replica loss pattern used with distribution strategies).
  loss_object = tf.keras.losses.BinaryCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
  # Running means of the per-step losses.
  train_gen_loss, validation_gen_loss, train_disc_loss, validation_disc_loss = [
      tf.keras.metrics.Mean(name=metric_name)
      for metric_name in ('train_gen_loss', 'validation_gen_loss',
                          'train_disc_loss', 'validation_disc_loss')]


In [0]:
def discriminator_loss(disc_real_output, disc_generated_output):
  """Discriminator loss: real pairs should score 1 and generated pairs 0.

  `loss_object` uses SUM reduction, so each term is scaled by 1 / global
  BATCH_SIZE rather than averaged (distribution-strategy convention).
  """
  scale = 1. / BATCH_SIZE
  real_term = loss_object(tf.ones_like(disc_real_output), disc_real_output) * scale
  fake_term = loss_object(tf.zeros_like(disc_generated_output),
                          disc_generated_output) * scale
  return real_term + fake_term

In [0]:
def generator_loss(disc_generated_output, gen_output, target):
  """Generator loss: fool the discriminator plus LAMBDA-weighted L1 to the target.

  Both terms are sums scaled by 1 / global BATCH_SIZE instead of means, so they
  aggregate correctly across replicas.
  NOTE(review): because both are sums, their relative weight also depends on
  INPUT_SIZE and the discriminator's output length, not only on LAMBDA.
  """
  batch_scale = 1. / BATCH_SIZE
  adversarial = loss_object(tf.ones_like(disc_generated_output),
                            disc_generated_output) * batch_scale
  # Absolute error summed over all samples and channels (no reduce_mean).
  l1_loss = tf.reduce_sum(tf.abs(target - gen_output)) * batch_scale
  return adversarial + (LAMBDA * l1_loss)

In [0]:
# cut tracks into consecutive INPUT_SIZE-sample chunks
with strategy.scope():
  def get_buffered_data(data, eval_accessor=""):
    """Cut every track into consecutive INPUT_SIZE-sample chunks.

    Args:
      data: iterable of musdb tracks.
      eval_accessor: which signal of each track `t` to read:
        - "" (default): the track itself, i.e. `t.audio`;
        - a string appended to `t` and evaluated with `eval`, e.g.
          ".targets['vocals']" (only pass trusted, hard-coded strings);
        - a callable `t -> source`, where `source.audio` is the array to cut.
    Returns:
      A float32 tensor `(n_chunks, INPUT_SIZE) + audio.shape[1:]`, with the
      chunks of all tracks in order. Each track's tail that does not fill a
      whole chunk is dropped; tracks shorter than INPUT_SIZE contribute nothing.
    Raises:
      ValueError: if `data` does not yield a single complete chunk.
    """
    chunks = []
    for t in data:
      if callable(eval_accessor):
        source = eval_accessor(t)
      elif eval_accessor:
        source = eval("t" + eval_accessor)
      else:
        source = t
      # Read `.audio` only once per track: it may decode from disk on every
      # access (TODO confirm for the installed musdb version).
      audio = np.asarray(source.audio)
      # NOTE(review): uses the source's own length; this assumes targets are as
      # long as their mixture (so chunks stay aligned) — confirm for the dataset.
      n_chunks = audio.shape[0] // INPUT_SIZE
      chunks.append(audio[:n_chunks * INPUT_SIZE].reshape(
          (n_chunks, INPUT_SIZE) + audio.shape[1:]))
    if not any(len(c) for c in chunks):
      raise ValueError('No complete chunk of {} samples in the given tracks'
                       .format(INPUT_SIZE))
    return tf.cast(np.concatenate(chunks, axis=0), dtype=tf.float32)

In [0]:
# Data Generators
with strategy.scope():
  def dataset_gen(n_data, data):
    """Yield (mixture_chunk, vocals_chunk) pairs from the first n_data tracks.

    Tracks are processed five at a time. Each yielded element is a pair of
    equally shaped chunks cut by `get_buffered_data`.
    """
    tracks_per_load = 5
    for start in range(0, n_data, tracks_per_load):
      group = data[start:start + tracks_per_load]
      pairs = tf.stack((
          get_buffered_data(group),
          get_buffered_data(group, ".targets['vocals']"),
      ), axis=1)  # (n_chunks, 2, ...)
      for pair in pairs:
        mixture_chunk, vocals_chunk = tf.unstack(pair)
        yield mixture_chunk, vocals_chunk

  def train_dataset_gen():
    """Generator factory for `tf.data.Dataset.from_generator` (training split)."""
    return dataset_gen(n_train, train_tracks)

  def validation_dataset_gen():
    """Generator factory for `tf.data.Dataset.from_generator` (validation split)."""
    return dataset_gen(n_validation, validation_tracks)

In [23]:
# Shape of one dataset element (a mixture or a target chunk).
# To inspect a real element instead: print(next(train_dataset_gen()))
print(tf.TensorShape([INPUT_SIZE, OUTPUT_CHANNELS]))

(32768, 2)


In [24]:
# DATASET
with strategy.scope():
  sample_shape = tf.TensorShape([INPUT_SIZE, OUTPUT_CHANNELS])
  # Every element is one (mixture_chunk, target_chunk) pair.
  pair_types = (tf.float32, tf.float32)
  pair_shapes = (sample_shape, sample_shape)

  train_dataset = tf.data.Dataset.from_generator(
      train_dataset_gen, output_types=pair_types, output_shapes=pair_shapes)
  validation_dataset = tf.data.Dataset.from_generator(
      validation_dataset_gen, output_types=pair_types, output_shapes=pair_shapes)

  # NOTE(review): the shuffle buffer holds n_train *chunks* (n_train counts
  # tracks), and the generator emits a track's chunks consecutively, so only
  # neighbouring chunks, mostly of the same track, get mixed.
  train_dataset = train_dataset.shuffle(buffer_size=n_train).repeat()
  train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=True).prefetch(AUTOTUNE)

  validation_dataset = validation_dataset.repeat()
  validation_dataset = validation_dataset.batch(
      BATCH_SIZE, drop_remainder=True).prefetch(AUTOTUNE)

  train_iterator = strategy.make_dataset_iterator(train_dataset)
  validation_iterator = strategy.make_dataset_iterator(validation_dataset)

  print(train_dataset)
  print(validation_dataset)

W0527 16:42:01.619108 140503906137984 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py:499: py_func (from tensorflow.python.ops.script_ops) is deprecated and will be removed in a future version.
Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, there are two
    options available in V2.
    - tf.py_function takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    - tf.numpy_function maintains the semantics of the deprecated tf.py_func
    (it is not differentiable, and manipulates numpy arrays). It drops the
    stateful argument making all functions stateful.
    


<DatasetV1Adapter shapes: ((16, 32768, 2), (16, 32768, 2)), types: (tf.float32, tf.float32)>
<DatasetV1Adapter shapes: ((16, 32768, 2), (16, 32768, 2)), types: (tf.float32, tf.float32)>


In [0]:
def test_models_are_working():
  """Smoke-test both models and both losses on one validation batch.

  Pulls a batch from `validation_iterator`, then calls `initialize()` on it
  (presumably to rewind it before the real validation passes).
  """
  batch = validation_iterator.get_next()
  validation_iterator.initialize()
  print(batch)

  mixture, target = batch
  gen_output = generator(mixture, training=False)
  print(gen_output)

  disc_real_output = discriminator([mixture, target], training=False)
  print(disc_real_output)
  disc_generated_output = discriminator([mixture, gen_output], training=False)
  print(disc_generated_output)

  gen_loss = generator_loss(disc_generated_output, gen_output, target)
  print("gen_loss : {}".format(gen_loss))
  disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
  print("disc_loss : {}".format(disc_loss))

# test_models_are_working()

In [26]:
# CHECKPOINTS
checkpoint_dir = '/content/musdb18/checkpoints/'

with strategy.scope():
  # Save both models and both optimizers, so a resumed run keeps optimizer state.
  checkpoint = tf.train.Checkpoint(
      generator_optimizer=generator_optimizer,
      discriminator_optimizer=discriminator_optimizer,
      generator=generator,
      discriminator=discriminator)
  # Only the 5 most recent checkpoints are kept on disk.
  ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

  if not ckpt_manager.latest_checkpoint:
    print("Initializing from scratch.")
  else:
    checkpoint.restore(ckpt_manager.latest_checkpoint)
    print("Restored from {}".format(ckpt_manager.latest_checkpoint))


Restored from /content/musdb18/checkpoints/ckpt-66


In [0]:
# TRAIN STEP
with strategy.scope():
  def _gan_losses(mixture, target, training):
    """Run generator and discriminator on one batch; return (gen_loss, disc_loss)."""
    gen_output = generator(mixture, training=training)

    disc_real_output = discriminator([mixture, target], training=training)
    disc_generated_output = discriminator([mixture, gen_output], training=training)

    gen_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    return gen_loss, disc_loss

  def train_step(inputs):
    """Update both models once on a per-replica (mixture, target) batch."""
    mixture, target = inputs
    # The forward pass must run inside both tapes so each can differentiate its loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      gen_loss, disc_loss = _gan_losses(mixture, target, training=True)

    # Take both gradients before applying either update.
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads,
                                                discriminator.trainable_variables))

    train_gen_loss.update_state(gen_loss)
    train_disc_loss.update_state(disc_loss)

  def validation_step(inputs):
    """Accumulate both losses on a per-replica batch in inference mode (no updates)."""
    mixture, target = inputs
    gen_loss, disc_loss = _gan_losses(mixture, target, training=False)

    validation_gen_loss.update_state(gen_loss)
    validation_disc_loss.update_state(disc_loss)



In [0]:
with strategy.scope():
  # `experimental_run` runs the step function on every replica, feeding each
  # one its share of the next element of the distributed iterator.

  @tf.function
  def distributed_train():
    """Run one replicated training step on the next batch of `train_iterator`."""
    return strategy.experimental_run(train_step, train_iterator)

  @tf.function
  def distributed_validation():
    """Run one replicated validation step on the next batch of `validation_iterator`."""
    return strategy.experimental_run(validation_step, validation_iterator)


In [0]:
# LEARNING RATE CALLBACK
def get_symbolic_metric(metric_obj):
  """Return `metric_obj.result()` with the metric attached as `_metric_obj`.

  `add_metric` expects result tensors tagged this way (see
  `keras.metrics.Metric.__call__`).
  """
  result = metric_obj.result()
  result._metric_obj = metric_obj
  return result

with strategy.scope():
  # The Keras callbacks are driven by hand (no `model.fit`), so give each model
  # the attributes they use: `optimizer` and `stop_training`.
  generator.optimizer = generator_optimizer
  generator.stop_training = False
  discriminator.optimizer = discriminator_optimizer
  discriminator.stop_training = False

  # Generator: callbacks keyed on its training-loss metric name.
  gen_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
      monitor=train_gen_loss.name, patience=10, verbose=True)
  gen_lr_callback.set_model(generator)
  gen_earlystop_callback = tf.keras.callbacks.EarlyStopping(
      monitor=train_gen_loss.name, patience=30, restore_best_weights=True, verbose=True)
  gen_earlystop_callback.set_model(generator)

  # Discriminator: callbacks keyed on its training-loss metric name.
  disc_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
      monitor=train_disc_loss.name, patience=10, verbose=True)
  disc_lr_callback.set_model(discriminator)
  disc_earlystop_callback = tf.keras.callbacks.EarlyStopping(
      monitor=train_disc_loss.name, patience=30, restore_best_weights=True, verbose=True)
  disc_earlystop_callback.set_model(discriminator)

In [0]:
# TENSORBOARD
logs_dir = '/content/musdb18/logs/'
# One writer per split, in sibling subdirectories of logs_dir.
train_summary_writer, validation_summary_writer = (
    tf.summary.create_file_writer(logs_dir + split) for split in ('train', 'validation'))

In [0]:
# TENSORBOARD SUMMARY
# A fixed example (first item of the first validation batch) is re-separated at
# every logging step so its audio can be compared over training.
test_mixture, test_target = validation_iterator.get_next()
test_mixture, test_target = tf.expand_dims(test_mixture[0], 0), tf.expand_dims(test_target[0], 0)
validation_iterator.initialize()

def update_summary(epoch):
  """Write loss/learning-rate scalars and an audio example to TensorBoard.

  Args:
    epoch: step value recorded with every summary.
  """
  with train_summary_writer.as_default():
    tf.summary.scalar('gen_loss', train_gen_loss.result(), step=epoch)
    # Previously logged train_gen_loss under this tag.
    tf.summary.scalar('disc_loss', train_disc_loss.result(), step=epoch)
    tf.summary.scalar('gen_lr', generator.optimizer.lr, step=epoch)
    tf.summary.scalar('disc_lr', discriminator.optimizer.lr, step=epoch)
  gen_output = tf.expand_dims(generator(test_mixture, training=False)[0], 0)
  with validation_summary_writer.as_default():
    tf.summary.scalar('gen_loss', validation_gen_loss.result(), step=epoch)
    tf.summary.scalar('disc_loss', validation_disc_loss.result(), step=epoch)

    # Three clips: input mixture, generated estimate, ground-truth target.
    tf.summary.audio('audio_separation',
                     tf.concat([test_mixture, gen_output, test_target], axis=0),
                     44100,  # MUSDB18 sample rate
                     step=epoch,
                     max_outputs=3,
                     encoding="wav"
                    )
  #!rsync -a --delete $logs_dir /content/gdrive/My\ Drive/musdb18/logs

In [0]:
with strategy.scope():
  # Custom training loop: the Keras callbacks are not driven by fit(), so each
  # hook (train begin, epoch end, train end) is invoked explicitly here.
  gen_lr_callback.on_train_begin()
  disc_lr_callback.on_train_begin()
  gen_earlystop_callback.on_train_begin()
  disc_earlystop_callback.on_train_begin()
  
  for epoch in range(EPOCHS):
    start = time.time()

    #train_iterator.initialize()
    for _ in range(train_steps_per_epoch):
      distributed_train()

    #validation_iterator.initialize()
    for _ in range(validation_steps_per_epoch):
      distributed_validation()

    # Epoch bookkeeping is plain host-side integer arithmetic; building a
    # tf.equal tensor and converting it back to bool is unnecessary.
    # NOTE(review): assumes SAVE_FREQ / LOG_FREQ are Python ints.
    # saving (checkpoint) the model every SAVE_FREQ epochs
    if (epoch + 1) % SAVE_FREQ == 0:
      ckpt_manager.save()
      #!rsync -a --delete $checkpoint_dir /content/gdrive/My\ Drive/musdb18/checkpoints
      
    if (epoch + 1) % LOG_FREQ == 0:
      update_summary(epoch)
    
    # Keyed by metric name, which is what the callbacks' `monitor` refers to.
    logs = {
        train_gen_loss.name: train_gen_loss.result(),
        train_disc_loss.name: train_disc_loss.result(),
        validation_gen_loss.name: validation_gen_loss.result(),
        validation_disc_loss.name: validation_disc_loss.result()
    }
    gen_lr_callback.on_epoch_end(epoch, logs)
    disc_lr_callback.on_epoch_end(epoch, logs)
    gen_earlystop_callback.on_epoch_end(epoch, logs)
    disc_earlystop_callback.on_epoch_end(epoch, logs)

    template = ("Epoch {} ({:.1f} sec)\nGen Loss: {}, Disc Loss: {}, "
                "Validation Gen Loss: {}, Validation Disc Loss: {}")
    tf.print(template.format(epoch+1, time.time()-start, 
                          logs[train_gen_loss.name], 
                          logs[train_disc_loss.name], 
                          logs[validation_gen_loss.name], 
                          logs[validation_disc_loss.name]))
    # Stop only once BOTH models' early-stopping callbacks have raised the flag.
    if gen_earlystop_callback.model.stop_training and disc_earlystop_callback.model.stop_training:
      break

    # Metrics accumulate across steps; clear them so each epoch reports its own values.
    train_gen_loss.reset_states()
    train_disc_loss.reset_states()
    validation_gen_loss.reset_states()
    validation_disc_loss.reset_states()
  gen_earlystop_callback.on_train_end()
  disc_earlystop_callback.on_train_end()

In [0]:
# Open the TensorBoard UI on the event files written under logs_dir.
%tensorboard --logdir $logs_dir

In [0]:
# EVALUATION FUNCS
# Output folder passed to mus.run(estimates_dir=...) for the separated stems.
estimates_dir = "/content/musdb18/estimates/"
# museval score JSON folder; the results cell reads its 'test' subfolder.
evaluation_scores_dir = "/content/musdb18/scores/"

def is_power2(num):
  """Return True if num is a positive power of two (1, 2, 4, ...), else False.

  A falsy num (e.g. 0) is returned unchanged rather than as False.
  """
  if not num:
    return num
  # A power of two has a single set bit; num & (num - 1) clears it to zero.
  return not (num & (num - 1))

def preprocess_data(audio):
  """Turn a (samples, channels) track into a batch of generator inputs.

  Tracks of at most INPUT_SIZE samples become a batch of one, zero-padded at
  the end to the next power of two, because upsample and downsample are not
  symmetric for other lengths. Longer tracks are zero-padded at the end to a
  multiple of INPUT_SIZE and cut into consecutive INPUT_SIZE-sample chunks.

  Args:
    audio: array-like track; assumed 2-D (samples, channels), which the 2-D
      paddings below require.

  Returns:
    A float32 tensor of shape (num_chunks, chunk_len, channels).

  Raises:
    ValueError: if the track contains no samples.
  """
  audio = tf.cast(audio, dtype=tf.float32)
  audio_len = int(audio.shape[0])
  if audio_len == 0:
    # There is no power of two to pad an empty track to.
    raise ValueError('cannot preprocess an empty track')
  if audio_len <= INPUT_SIZE:
    if not is_power2(audio_len):
      # Exact integer form of 2**(floor(log2(n)) + 1): next power of two above n.
      new_width = 1 << audio_len.bit_length()
      audio = tf.pad(audio, [(0, new_width - audio_len), (0, 0)])
    return tf.expand_dims(audio, 0)
  # Pad only up to the next multiple of INPUT_SIZE (0 when already aligned),
  # so no all-zero chunk is pushed through the generator.
  paddings = -audio_len % INPUT_SIZE
  audio = tf.pad(audio, [(0, paddings), (0, 0)])
  # Row-major reshape yields the same consecutive chunks as split + stack.
  return tf.reshape(audio, [-1, INPUT_SIZE, int(audio.shape[1])])
  
# Number of preprocessed chunks fed to the generator per call in test_func.
# NOTE(review): if BATCH_SIZE is also defined earlier for training, this
# reassignment overwrites it for any training cell re-run afterwards.
BATCH_SIZE = 64
def postprocess_data(gen_output, original_width):
  """Stitch batched generator outputs back into one (samples, channels) array.

  Args:
    gen_output: list of generator output tensors, joined along axis 0.
    original_width: sample count of the source track; rows past this index
      (padding) are dropped.

  Returns:
    A numpy array with OUTPUT_CHANNELS columns and at most original_width rows.
  """
  all_chunks = tf.concat(gen_output, axis=0)
  samples = tf.reshape(all_chunks, [-1, OUTPUT_CHANNELS])
  return samples[:original_width].numpy()
      
def test_func(track):
  """Separate one musdb track into vocals and accompaniment estimates.

  The mixture is chunked by preprocess_data, run through the generator
  BATCH_SIZE chunks per call in inference mode, and stitched back to the
  original length. Accompaniment is the mixture minus the predicted vocals.

  Returns:
    dict with 'vocals' and 'accompaniment' arrays.
  """
  # Read the mixture once: musdb may load/decode audio on every `.audio`
  # access, and it is needed three times below.
  mixture = track.audio
  chunks = preprocess_data(mixture)
  predictions = [
      generator(chunks[start:start + BATCH_SIZE], training=False)
      for start in range(0, chunks.shape[0], BATCH_SIZE)
  ]
  vocals = postprocess_data(predictions, mixture.shape[0])
  estimates = {'vocals': vocals, 'accompaniment': mixture - vocals}
  
  # Evaluate using museval, use EVALUATION ONLY to separate process
  #scores = museval.eval_mus_track(
  #    track, estimates, output_dir=evaluation_scores_dir
  #)
  
  return estimates

In [31]:
# EVALUATE
# mus.test(test_func)
# Trailing ';' suppresses the cell's echo of mus.run's return value (one None per track).
mus.run(test_func, subsets="test", estimates_dir=estimates_dir);
#!rsync -a --delete $estimates_dir /content/gdrive/My\ Drive/musdb18/estimates
#!rsync -a --delete $evaluation_scores_dir /content/gdrive/My\ Drive/musdb18/scores

100%|██████████| 50/50 [26:24<00:00, 26.19s/it]


[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [0]:
# EVALUATION ONLY
#!rm -rf $evaluation_scores_dir
#museval.eval_mus_dir(
#    dataset=mus,  # instance of musdb
#    estimates_dir=estimates_dir,  # path to estimate folder
#    output_dir=evaluation_scores_dir,  # set a folder to write eval json files
#    subsets="test",
#    parallel=True
#)

In [0]:
# DISPLAY RESULTS
import json
from statistics import median
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

SDR, SIR, SAR, ISR = [], [], [], []
for file in (Path(evaluation_scores_dir) / 'test').iterdir():
  with open(file, 'r') as f:
    data = json.loads(f.read())
    frames = data['targets'][0]['frames']
    SDR.append(median([frame['metrics']['SDR'] for frame in frames if not math.isnan(frame['metrics']['SDR'])]))
    SIR.append(median([frame['metrics']['SIR'] for frame in frames if not math.isnan(frame['metrics']['SIR'])]))
    SAR.append(median([frame['metrics']['SAR'] for frame in frames if not math.isnan(frame['metrics']['SAR'])]))
    ISR.append(median([frame['metrics']['ISR'] for frame in frames if not math.isnan(frame['metrics']['ISR'])]))
    
fig = plt.figure()

ax = plt.subplot(1, 4, 1) # (rows, columns, panel number)
ax.set_title('SDR')
plt.boxplot(SDR, vert=False, notch=True, showfliers=False)

ax = plt.subplot(1, 4, 2)
ax.set_title('SIR')
plt.boxplot(SIR, vert=False, notch=True, showfliers=False)

ax = plt.subplot(1, 4, 3)
ax.set_title('SAR')
plt.boxplot(SAR, vert=False, notch=True, showfliers=False)

ax = plt.subplot(1, 4, 4)
ax.set_title('ISR')
plt.boxplot(ISR, vert=False, notch=True, showfliers=False)
fig.savefig(Path(evaluation_scores_dir) / 'result.png')