In [1]:
import librosa
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
import tensorflow_addons as tfa

from datetime import datetime
import os
import sys

# Put the parent of the current working directory on the import path so the
# project-local `utils` package imported right below can be resolved.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
from utils.helper import get_filenames, istft
from utils.dataset import tfrecord2dataset

## Preprocessing
### Parameter configuration for STFT

In [2]:
# parameter config
# Independent STFT / audio settings.
N_FFT = 4096
HOP_LEN = 1024
WIN_LEN = 4096
SR = 44100
DURATION = 2.0

# Derived sizes, computed instead of hand-copied so they cannot drift out of
# sync with the settings above.
# A one-sided spectrum of an N_FFT-point transform has N_FFT // 2 + 1 bins (2049).
FREQ_BINS = N_FFT // 2 + 1
# 1 + n_samples // hop gives 87 for a 2 s clip at 44.1 kHz, which equals the
# previously hard-coded value.
# NOTE(review): this is the frame count of a centre-padded STFT; confirm the
# TFRecords were generated with that framing before changing SR/DURATION/HOP_LEN.
TIME_FRAMES = int(SR * DURATION) // HOP_LEN + 1

### Load training and test datasets

In [3]:
# The data directory lives next to the notebook directory: <parent of cwd>/data
root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
data_dir = os.path.join(root, 'data')


def _load_split(tfrecords_dirname):
    """Build the dataset for one split stored under ``data_dir``.

    Returns a ``(split_dir, filenames, dataset)`` tuple where ``filenames`` is
    whatever ``get_filenames`` yields for the ``<split_dir>/*`` pattern and
    ``dataset`` is the result of feeding those names to ``tfrecord2dataset``.
    """
    split_dir = os.path.join(data_dir, tfrecords_dirname)
    filenames = get_filenames(split_dir + '/*')
    return split_dir, filenames, tfrecord2dataset(filenames)


# training data
train_data_dir, train_tfrecords_zipfiles, train_dataset = _load_split('dsd100_train_tfrecords')
# test data
test_data_dir, test_tfrecords_zipfiles, test_dataset = _load_split('dsd100_test_tfrecords')

## Audio Separator Model
### 1D transposed convolution

In [4]:
# 1d transposed convolution
class Conv1DTranspose(layers.Layer):
    """1D transposed convolution emulated with ``Conv2DTranspose``.

    A dummy axis is inserted at position 2, the tensor goes through a
    ``Conv2DTranspose`` whose kernel and stride are both 1 along that dummy
    axis, the dummy axis is removed again and a LeakyReLU (alpha=0.01) is
    applied to the result.

    Args:
        filters: number of output channels of the transposed convolution.
        kernel_size: kernel length along axis 1 of the input.
        strides: stride along axis 1 of the input (default 1).
        name: optional layer name forwarded to ``layers.Layer``.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, kernel_size, strides=1, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.expand_dim = layers.Lambda(lambda t: tf.expand_dims(t, axis=2))
        self.conv2dtranspose = layers.Conv2DTranspose(
            filters,
            kernel_size=(kernel_size, 1),
            strides=(strides, 1),
            padding='same',
        )
        self.squeeze_dim = layers.Lambda(lambda t: tf.squeeze(t, axis=2))
        self.activation = layers.LeakyReLU(alpha=0.01)

    @tf.function
    def call(self, inputs):
        """Apply expand -> transposed conv -> squeeze -> LeakyReLU to ``inputs``."""
        widened = self.expand_dim(inputs)
        deconvolved = self.conv2dtranspose(widened)
        return self.activation(self.squeeze_dim(deconvolved))
    
# upsampling + conv1D
class UpConv1D(layers.Layer):
    """Upsample by a factor of 2, then apply a Conv1D followed by LeakyReLU.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: kernel length of the convolution.
        strides: stride of the convolution.
        name: optional layer name forwarded to ``layers.Layer``.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self,
                 filters,
                 kernel_size,
                 strides,
                 name=None,
                 **kwargs):
        super(UpConv1D, self).__init__(name=name, **kwargs)
        self.upsampling = layers.UpSampling1D(size=2)
        self.conv1d = layers.Conv1D(filters, kernel_size, strides, padding='same')
        # Was misspelled `LekayReLU`, which raised AttributeError on construction.
        self.activation = layers.LeakyReLU(alpha=0.01)

    @tf.function
    def call(self, inputs):
        """Return ``LeakyReLU(Conv1D(UpSampling1D(inputs)))``."""
        x = self.upsampling(inputs)
        x = self.conv1d(x)
        x = self.activation(x)
        return x

### Custom layers and models

In [5]:
class Encoder(layers.Layer):
    """Three weight-normalised Conv1D layers, each followed by LeakyReLU.

    Each example is first reshaped to ``(time_frames, frequency_bins)`` and then
    passed through convolutions with ``frequency_bins // 2``,
    ``frequency_bins // 4`` and ``frequency_bins // 8`` filters (kernel size 3,
    stride 1, 'same' padding).

    Args:
        frequency_bins: number of frequency bins per example.
        time_frames: number of time frames per example.
        name: layer name (default ``'encoder'``).
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self,
                 frequency_bins,
                 time_frames,
                 name='encoder',
                 **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.freq_bins = frequency_bins
        self.time_frames = time_frames
        # Created once here instead of on every call/trace.
        # NOTE(review): Reshape reinterprets the memory layout, it does not
        # transpose. If the inputs are laid out as (freq_bins, time_frames),
        # a Permute may be what was intended - confirm against the dataset.
        self.reshape = layers.Reshape((self.time_frames, self.freq_bins))
        self.conv1 = self._weight_normed_conv(self.freq_bins // 2)
        self.conv2 = self._weight_normed_conv(self.freq_bins // 4)
        self.conv3 = self._weight_normed_conv(self.freq_bins // 8)
        # The activation used to be passed as the second positional argument of
        # WeightNormalization. That slot is `data_init`, not an activation, so
        # depending on the tensorflow_addons version the LeakyReLU was either
        # rejected by the argument type check or treated as a truthy flag - it
        # was never applied to the convolution output. It is applied in call().
        self.activation = layers.LeakyReLU(alpha=0.01)

    @staticmethod
    def _weight_normed_conv(filters):
        """Return a weight-normalised Conv1D (kernel 3, stride 1, 'same' padding)."""
        return tfa.layers.WeightNormalization(
            layers.Conv1D(filters=filters, kernel_size=3, strides=1, padding='same'),
            data_init=True)

    @tf.function
    def call(self, inputs):
        """Return the activated outputs ``(conv1, conv2, conv3)`` of the three stages."""
        reshaped_inputs = self.reshape(inputs)
        conv1 = self.activation(self.conv1(reshaped_inputs))
        conv2 = self.activation(self.conv2(conv1))
        conv3 = self.activation(self.conv3(conv2))
        return conv1, conv2, conv3

    
class Decoder(layers.Layer):
    """Three transposed-convolution stages with additive skip connections.

    Consumes the ``(conv1, conv2, conv3)`` tuple produced by ``Encoder`` and
    returns a tensor reshaped to ``(frequency_bins, time_frames)`` per example.

    Args:
        frequency_bins: number of frequency bins per example.
        time_frames: number of time frames per example.
        name: layer name (default ``'decoder_skip_conn'``).
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self,
                 frequency_bins,
                 time_frames,
                 name='decoder_skip_conn',
                 **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.freq_bins = frequency_bins
        self.time_frames = time_frames
        self.tconv4 = Conv1DTranspose(filters=self.freq_bins // 4, kernel_size=3)
        self.tconv5 = Conv1DTranspose(filters=self.freq_bins // 2, kernel_size=3)
        self.tconv6 = Conv1DTranspose(filters=self.freq_bins, kernel_size=3)
        # Sub-layers are built once here; they used to be re-instantiated
        # inside the tf.function-decorated call().
        self.skip_conn1 = layers.Add(name='skip_conn1')
        self.skip_conn2 = layers.Add(name='skip_conn2')
        self.reshape = layers.Reshape((self.freq_bins, self.time_frames))

    @tf.function
    def call(self, inputs):
        """Decode the encoder outputs ``(conv1, conv2, conv3)``."""
        conv1, conv2, conv3 = inputs
        # 1st deconvolution layer with skip connection
        tconv4 = self.tconv4(conv3)
        tconv4_output = self.skip_conn1([tconv4, conv2])
        # 2nd deconvolution layer with skip connection
        tconv5 = self.tconv5(tconv4_output)
        tconv5_output = self.skip_conn2([tconv5, conv1])
        # output deconvolution layer
        tconv6 = self.tconv6(tconv5_output)
        return self.reshape(tconv6)

class DenoisingAutoencoder(keras.Model):
    """Autoencoder that chains an ``Encoder`` with a skip-connection ``Decoder``.

    Args:
        frequency_bins: number of frequency bins, passed to both sub-layers.
        time_frames: number of time frames, passed to both sub-layers.
        name: model name (default ``'denoising_autoencoder'``).
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, frequency_bins, time_frames,
                 name='denoising_autoencoder', **kwargs):
        super().__init__(name=name, **kwargs)
        self.encoder = Encoder(frequency_bins, time_frames)
        self.decoder = Decoder(frequency_bins, time_frames)

    @tf.function
    def call(self, inputs):
        """Encode ``inputs`` and decode the resulting feature maps."""
        return self.decoder(self.encoder(inputs))
    

class AutoencoderExpandDim(keras.Model):
    """Autoencoder whose output is widened to ``4 * frequency_bins`` channels.

    The encoder/decoder pair is followed by a small stack that reshapes each
    example to ``(time_frames, frequency_bins)``, applies a ``Conv1DTranspose``
    with ``frequency_bins * 4`` filters and reshapes the result to
    ``(frequency_bins * 4, time_frames)``.

    Args:
        frequency_bins: number of frequency bins per example.
        time_frames: number of time frames per example.
        name: model name (default ``'autoencoder_expand_output_dim'``).
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, frequency_bins, time_frames,
                 name='autoencoder_expand_output_dim', **kwargs):
        super().__init__(name=name, **kwargs)
        self.encoder = Encoder(frequency_bins, time_frames)
        self.decoder = Decoder(frequency_bins, time_frames)

        widened_bins = frequency_bins * 4
        expansion_stack = [
            layers.Reshape((time_frames, frequency_bins)),
            Conv1DTranspose(filters=widened_bins, kernel_size=3),
            layers.Reshape((widened_bins, time_frames)),
        ]
        self.expand_dim = keras.Sequential(layers=expansion_stack)

    @tf.function
    def call(self, inputs):
        """Run encoder, decoder and the channel-expansion stack in sequence."""
        decoded = self.decoder(self.encoder(inputs))
        return self.expand_dim(decoded)

In [6]:
class Separator(keras.Model):
    """Source separator: one expanding autoencoder feeding four per-stem autoencoders.

    The reconstruction network outputs ``4 * frequency_bins`` rows per example
    (see the final Reshape of ``AutoencoderExpandDim``). That axis is cut into
    four consecutive chunks of ``frequency_bins`` rows, and each chunk is handed
    to its own ``DenoisingAutoencoder`` (vocals, bass, drums, other).

    Args:
        frequency_bins: number of frequency bins per example.
        time_frames: number of time frames per example.
        name: model name (default ``'Denoising_autoencoder_separator'``).
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self,
                 frequency_bins,
                 time_frames,
                 name='Denoising_autoencoder_separator',
                 **kwargs):
        super(Separator, self).__init__(name=name, **kwargs)
        self.freq_bins = frequency_bins
        self.time_frames = time_frames

        # autoencoder for reconstruction, output with expanded dimension (4*frequency_bins)
        self.reconstruction = AutoencoderExpandDim(frequency_bins, time_frames)
        # denoising autoencoder separator for STEM
        self.vocals_sep = DenoisingAutoencoder(frequency_bins, time_frames)
        self.bass_sep = DenoisingAutoencoder(frequency_bins, time_frames)
        self.drums_sep = DenoisingAutoencoder(frequency_bins, time_frames)
        self.other_sep = DenoisingAutoencoder(frequency_bins, time_frames)

    @tf.function
    def call(self, inputs):
        """Return the four separated stems stacked along a new axis 1.

        The stacking order is vocals, bass, drums, other.
        """
        recon_expand_dim = self.reconstruction(inputs)
        n_bins = self.freq_bins
        # Slice axis 1 (the 4 * freq_bins axis). The previous `recon[:n]` form
        # sliced axis 0, i.e. the batch, so the stems received whole examples
        # (or nothing at all) instead of their share of the expanded output.
        vocals = self.vocals_sep(recon_expand_dim[:, :n_bins])
        bass = self.bass_sep(recon_expand_dim[:, n_bins:n_bins * 2])
        drums = self.drums_sep(recon_expand_dim[:, n_bins * 2:n_bins * 3])
        other = self.other_sep(recon_expand_dim[:, n_bins * 3:])
        return tf.stack([vocals, bass, drums, other], axis=1)

## Training

In [None]:
# train config
EPOCHS = 200
lr = 0.0001
BATCH_SIZE = 256

# learning rate decay function
lr_fn = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=lr,
    decay_steps=10000,
    decay_rate=0.9,
    staircase=True)
# early stopping on the training loss
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', min_delta=1e-3, verbose=True, patience=5)

model = Separator(frequency_bins=FREQ_BINS, time_frames=TIME_FRAMES)
# Drive the optimizer with the schedule defined above. Previously `lr_fn` and
# `lr` were dead code and the rate was hard-coded a second time. The schedule
# starts at `lr`, so the initial learning rate is unchanged.
optimizer = keras.optimizers.Adam(learning_rate=lr_fn)
mse_loss = keras.losses.MeanSquaredError()

model.compile(optimizer=optimizer, loss=mse_loss)
# NOTE(review): 6488 is presumably the number of training examples and the
# dataset is presumably already batched with BATCH_SIZE (no batch size is
# passed to fit) - confirm against utils.dataset.tfrecord2dataset.
history = model.fit(train_dataset,
                    steps_per_epoch=6488 // BATCH_SIZE,
                    epochs=EPOCHS,
                    callbacks=[callback])

In [None]:
# Print the summary of the Separator instance created in the training cell.
model.summary()

### Save model

In [None]:
saved_model_dir = os.path.join(root, "notebook", "dae_expand_dimension_separator")
# makedirs(exist_ok=True) also creates missing parents and does not fail when
# the directory already exists.
os.makedirs(saved_model_dir, exist_ok=True)
# save model
# The trailing "jj" in the old timestamp format was stray text and is removed.
date_time = datetime.now().strftime("%Y-%m-%d_%H:%M")
# No leading "/" on the last component: os.path.join throws away everything
# before an absolute component, which sent the export to the filesystem root
# instead of saved_model_dir.
saved_model_path = os.path.join(saved_model_dir, "dae_expand_dim_{}".format(date_time))
tf.saved_model.save(model, saved_model_path)

In [None]:
# Manual training loop (alternative to model.fit) using the `model`,
# `optimizer` and `mse_loss` objects created in the training cell.
train_loss_results = []
# Never filled: the loop only tracks the MSE loss. Kept because the plotting
# cell reads this name.
train_accuracy_results = []
epoch_loss_history = keras.metrics.Mean()

# NOTE(review): this rebinds the EPOCHS constant set in the training cell.
EPOCHS = 150
for epoch in range(EPOCHS):
    print('START of Epoch %d' % (epoch,))
    # Fresh running mean per epoch. Without a reset the reported value was the
    # average over every batch of every epoch seen so far.
    epoch_loss_history = keras.metrics.Mean()
    # NOTE(review): assumes train_dataset yields (input, target) pairs and is
    # finite; if it repeats indefinitely this inner loop never ends - confirm.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            y_pred = model(x_batch_train)
            # Keras losses take (y_true, y_pred). The loss object already
            # reduces to a scalar, so no extra reduce_mean is needed.
            loss = mse_loss(y_batch_train, y_pred)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # add current batch loss (was a call to the undefined `epoch_loss_avg`)
        epoch_loss_history.update_state(loss)
        if step % 10 == 0:
            print('step {}: mean loss = {}'.format(step, float(epoch_loss_history.result())))

    # after each epoch: store a plain float so the history can be plotted and
    # formatted without depending on tensor objects
    epoch_loss = float(epoch_loss_history.result())
    train_loss_results.append(epoch_loss)
    if epoch % 50 == 0:
        print("Epoch {:03d}: Loss: {:.3f}".format(epoch, epoch_loss))
print('END')

In [None]:
# Visualization
# NOTE(review): `plt` is not imported in any visible cell; this cell needs
# `import matplotlib.pyplot as plt` in the import cell to run on a fresh kernel.
# Two stacked panels sharing the x axis: loss on top, accuracy below.
fig, axes = plt.subplots(2, sharex=True, figsize=(12, 8))
loss_ax, accuracy_ax = axes[0], axes[1]
fig.suptitle('Training Metrics')

loss_ax.set_ylabel("Loss", fontsize=14)
loss_ax.plot(train_loss_results)

accuracy_ax.set_ylabel("Accuracy", fontsize=14)
accuracy_ax.set_xlabel("Epoch", fontsize=14)
accuracy_ax.plot(train_accuracy_results)
plt.show()