In [None]:
import os
import sys


# Make the parent of the current working directory importable, so the project-local
# packages imported in the next cell (`utils`, `models`, `training`) can be resolved.
# The membership check keeps repeated runs of this cell from adding duplicate entries.
module_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)

In [None]:
import numpy as np
import tensorflow as tf
import IPython.display as ipd
from utils.helper import wav_to_spectrogram_clips, rebuild_audio_from_spectro_clips
from utils.dataset import create_samples
from models.conv_denoising_unet import ConvDenoisingUnet
from training.plot import plot_curve, plot_learning_curves

In [None]:
# Load the 'Dev' samples and train on the first one only.
samples = create_samples('Dev')
train_sample = samples[0]

# Network input: spectrogram clips of the mixture.
x_train = wav_to_spectrogram_clips(train_sample['mix'])

# Targets: spectrogram clips for each stem, keyed by stem name. These keys must match
# the per-output loss keys passed to `model.compile`.
y_train = {stem: wav_to_spectrogram_clips(train_sample[stem])
           for stem in ('vocals', 'bass', 'drums', 'other')}

In [None]:
# separator model
# Seed TensorFlow's global RNG before building the network so weight initialisation,
# and therefore a Restart & Run All, is reproducible.
tf.random.set_seed(0)
# NOTE(review): the positional arguments (1025, 100, (3, 3)) are interpreted by
# ConvDenoisingUnet. They are presumably spectrogram frequency bins, frames per clip and
# the conv kernel size; confirm against models/conv_denoising_unet.py.
separator = ConvDenoisingUnet(1025, 100, (3, 3))
model = separator.get_model()
model.summary()


# BEGIN TRAINING
# One MSE loss per named output head. The keys must match the model's output names
# and the keys of `y_train`.
# `learning_rate` replaces the deprecated `lr` alias, which Keras 3 rejects outright.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss={stem: tf.keras.losses.MeanSquaredError()
                    for stem in ('vocals', 'bass', 'drums', 'other')})

# verbose=2 logs one line per epoch instead of a progress bar.
history = model.fit(x_train, y_train,
                    batch_size=1,
                    epochs=50,
                    verbose=2)

In [None]:
# Run the trained model on the training mixture. `x_train` already holds the spectrogram
# clips of train_sample['mix'], so reuse it instead of recomputing the transform.
pred = model.predict(x_train)

# Pick the vocals head by name instead of assuming it is the first output.
# List-style predictions follow `model.output_names`; dict-style predictions are keyed by name.
if isinstance(pred, dict):
    pred_vocals_raw = pred['vocals']
else:
    pred_vocals_raw = pred[list(model.output_names).index('vocals')]

# Drop the trailing size-1 axis before the clips are turned back into audio.
pred_vocal = np.squeeze(pred_vocals_raw, axis=-1)
pred_vocal.shape

In [None]:
# Convert the predicted vocal spectrogram clips back into a waveform and render an
# audio player for it (44100 Hz playback rate).
separated_vocals = rebuild_audio_from_spectro_clips(pred_vocal)
ipd.Audio(data=separated_vocals, rate=44100)

In [None]:
# Reference listening. Push the ground-truth vocal spectrogram clips (already computed as
# y_train['vocals']) through the same inversion used for the model output. This shows how
# much quality the spectrogram round trip alone loses. Previously this value was computed
# but never played.
reconstructed_vocal = rebuild_audio_from_spectro_clips(y_train['vocals'])

# Original ground-truth vocals, as before...
ipd.display(ipd.Audio(train_sample['vocals'], rate=44100))
# ...followed by their round-trip reconstruction.
ipd.Audio(reconstructed_vocal, rate=44100)