In [1]:
import matplotlib.pyplot as plt
import numpy as np
import IPython
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow.keras.layers as tfkl
import tensorflow_io as tfio
import functools
from pedalboard import load_plugin
from ReverberatorEstimator import layers, loss
import warnings
warnings.filterwarnings('ignore')
import time
import os
import librosa.display
# Report how many GPUs TensorFlow can see, so a CPU-only kernel is obvious up front.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

Num GPUs Available:  1


In [2]:
# Sample rate (Hz) shared by the audio clips, the log-mel layer, the VST processor and the loss.
sample_rate = 48000
# Number of parameters the encoder predicts. Per the parameter print-out cell:
# 4 values b_0..b_3, 4 values c_0..c_3, then 4 filters x 5 values (c_hp, c_bp, c_lp, g, R).
num_params = 4 + 4 + 4 * 5

In [3]:
def _load_mono_clip(path, num_samples=96000):
    """Load the first channel of a WAV file as a float32 tensor of shape (1, num_samples).

    16-bit PCM samples are scaled to [-1, 1) by dividing by 32768; floating-point
    files are used as-is (dividing those by 32768 would make them nearly silent).
    tf.reshape raises if the clip does not contain exactly num_samples samples.

    Raises:
        ValueError: if the file's sample rate differs from `sample_rate`, or its
            sample type is neither int16 nor floating point.
    """
    audio_io = tfio.audio.AudioIOTensor(path)
    file_rate = int(audio_io.rate)
    if file_rate != sample_rate:
        raise ValueError("%s has sample rate %d Hz, expected %d Hz" % (path, file_rate, sample_rate))
    # to_tensor() is (samples, channels); keep only the first channel -> (samples,)
    audio = audio_io.to_tensor()[:, 0]
    if audio.dtype == tf.int16:
        audio = tf.cast(audio, tf.float32) / 32768.0
    elif audio.dtype.is_floating:
        audio = tf.cast(audio, tf.float32)
    else:
        raise ValueError("Unsupported WAV sample type %s in %s" % (audio.dtype.name, path))
    # Leading batch axis of 1 so the clip can be fed straight to the models.
    return tf.reshape(audio, (1, num_samples))


# Evaluation pair used throughout the notebook: wet target and dry input, each (1, 96000).
target_audio = _load_mono_clip("Dataset/Wet/Snap.wav")
input_audio = _load_mono_clip("Dataset/Dry/Snap.wav")

In [4]:
dataset_path = os.path.abspath("./Dataset")


def _load_repeated_clip(subdir, num_copies=8, num_samples=96000):
    """Load <dataset_path>/<subdir>/Snap.wav once and stack num_copies copies of it.

    Only the first channel is kept. 16-bit PCM samples are scaled to [-1, 1) by
    dividing by 32768; floating-point files are used as-is. Returns a float32 tensor
    of shape (num_copies, num_samples); tf.reshape raises if the clip length differs.

    Raises:
        ValueError: if the file's sample rate differs from `sample_rate`, or its
            sample type is neither int16 nor floating point.
    """
    path = os.path.join(dataset_path, subdir, "Snap.wav")
    audio_io = tfio.audio.AudioIOTensor(path)
    file_rate = int(audio_io.rate)
    if file_rate != sample_rate:
        raise ValueError("%s has sample rate %d Hz, expected %d Hz" % (path, file_rate, sample_rate))
    # to_tensor() is (samples, channels); keep only the first channel -> (samples,)
    clip = audio_io.to_tensor()[:, 0]
    if clip.dtype == tf.int16:
        clip = tf.cast(clip, tf.float32) / 32768.0
    elif clip.dtype.is_floating:
        clip = tf.cast(clip, tf.float32)
    else:
        raise ValueError("Unsupported WAV sample type %s in %s" % (clip.dtype.name, path))
    clip = tf.reshape(clip, (num_samples,))  # explicit 1-tuple; `(96000)` is just an int
    # Read once and repeat instead of re-reading the same file num_copies times.
    # NOTE(review): every row is the same Snap.wav clip, so the training "set" is one
    # example repeated -- presumably deliberate while overfitting a single clip; confirm.
    return tf.stack([clip] * num_copies)


# x_train: dry inputs, y_train: matching wet targets, both (8, 96000).
x_train = _load_repeated_clip("Dry")
y_train = _load_repeated_clip("Wet")

In [5]:
# Pipeline: dry audio -> log-mel spectrogram -> MobileNetV2 encoder predicting num_params
# sigmoid outputs in [0, 1] -> VST reverberator driven by those parameters -> wet audio.
# NOTE(review): positional args presumably (n_fft, hop, n_mels, sr, f_min, f_max, eps) --
# confirm against ReverberatorEstimator.layers.LogMelgramLayer.
logmelgram = layers.LogMelgramLayer(1024, 256, 128, sample_rate, 0.0, sample_rate//2, 1e-6)
# 96000 samples = 2 s at sample_rate = 48000.
audio_time = tfkl.Input(shape=(96000,), name="audio_time")

x = logmelgram(audio_time)
x = tfkl.BatchNormalization(name="input_norm")(x)
# MobileNetV2 as an untrained encoder: no pretrained weights, and its top layer sized to
# one sigmoid output per reverberator parameter. x is rank 4 (batch + 3 feature dims).
encoder_model = tfk.applications.MobileNetV2(input_shape=tuple(x.shape[1:]), alpha=1.0,
                                            include_top=True, weights=None, input_tensor=None, pooling=None,
                                            classes=num_params, classifier_activation="sigmoid")

hidden = encoder_model(x)

parameter_model = tfk.models.Model(audio_time, hidden, name="parameter_model")

parameters = parameter_model(audio_time)

vstlayer = layers.VSTProcessor("../Reverberator.vst3", sample_rate)
output = vstlayer([audio_time, parameters])

model = tfk.models.Model(audio_time, output, name="full_model")

spectral_loss = loss.multiScaleSpectralLoss(sr=sample_rate)

optimizer = tfk.optimizers.Adam(learning_rate=0.001)

checkpoint_dir = './training_checkpoints'

# NOTE(review): the saved output of this cell is "TypeError: cannot pickle 'VST3Plugin'
# object". That message is raised when something copies/pickles the plugin object held by
# VSTProcessor. The traceback was not saved, so the failing line is unknown -- TODO check
# whether VSTProcessor needs custom __deepcopy__/get_config handling.
# run_eagerly=True: presumably required because the VST processing runs as Python code
# that cannot be traced into a graph -- confirm.
model.compile(optimizer=optimizer, loss=spectral_loss, metrics=['mae'], run_eagerly=True)

TypeError: cannot pickle 'VST3Plugin' object

In [None]:
# Layer-by-layer summaries: the parameter estimator first, then the full estimator + VST model.
for summarised_model in (parameter_model, model):
    summarised_model.summary()

In [None]:
# Restore from latest checkpoint, if one exists. On a fresh kernel/checkout this cell runs
# before any training, so load_weights would raise; skip with a message instead.
# Both layouts are checked: a SavedModel directory, or a weights-only "<prefix>.index" file.
if os.path.exists(checkpoint_dir) or os.path.exists(checkpoint_dir + ".index"):
    model.load_weights(checkpoint_dir)
else:
    print("No checkpoint found at %s; using freshly initialised weights." % checkpoint_dir)

In [None]:
# Snapshot the current (pre-fit) model output and predicted parameters so they can be
# compared with the post-training results further down. Call the model via __call__,
# as the rest of the notebook does, rather than bypassing Keras with model.call().
audio_pre = model(input_audio).numpy()[0]
old_params = parameter_model(input_audio).numpy()[0]
print(old_params)

In [None]:
# Keep only the best weights (lowest training loss). Weights-only checkpoints are used
# because the restore cell reloads with model.load_weights, and because the full model
# contains the VST plugin layer, whose plugin object failed to pickle when the model was
# built (see that cell's output), so full-model saving is likely to fail.
model_cp = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir,
                                              monitor='loss',
                                              verbose=1,
                                              save_best_only=True,
                                              save_weights_only=True,
                                              mode='min')

# Reduce the learning rate by 5x after 10 epochs without improvement in training loss.
# NOTE(review): with epochs=10 in the training cell, at most 9 non-improving epochs can
# follow the first one, so this callback never fires -- raise epochs or lower patience.
# min_lr = 1e-3 * 0.2**4, i.e. at most four reductions from the initial learning rate.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='loss',
                                                 factor=0.2,
                                                 patience=10,
                                                 cooldown=0,
                                                 verbose=1,
                                                 mode='min',
                                                 min_lr=0.0000016)

In [None]:
# Fit the dry -> wet mapping with the checkpoint and LR-plateau callbacks, and report
# the wall-clock training time.
start_time = time.time()
history = model.fit(
    x_train,
    y_train,
    epochs=10,
    verbose=1,
    callbacks=[model_cp, reduce_lr],
)
elapsed_seconds = time.time() - start_time
print("Training took %d seconds" % elapsed_seconds)

In [None]:
# Training-loss curve. model.fit was called without validation data, so history only has
# the training 'loss'; the legend therefore lists just that one curve (there is no 'test').
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='train')
ax.set_title('model loss')
ax.set_ylabel('loss')
ax.set_xlabel('epoch')
ax.legend(loc='upper left')
plt.show()

In [None]:
# Run the full (trained) model on the dry input clip. output_audio keeps the batch axis,
# so the cells below take its waveform as output_audio.numpy()[0].
output_audio = model(input_audio)

In [None]:
# Model output captured before fitting (audio_pre): waveform on top, dB spectrogram below.
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(7, 10))
ax[0].plot(audio_pre)
ax[0].set(title="Model output (pre-fit): waveform", xlabel="sample", ylabel="amplitude")
D = librosa.amplitude_to_db(np.abs(librosa.stft(audio_pre)), ref=np.max)
img = librosa.display.specshow(D, y_axis='linear', x_axis='time',
                               sr=sample_rate, ax=ax[1])
ax[1].set(title="Model output (pre-fit): spectrogram")
fig.colorbar(img, ax=ax[1], format="%+2.0f dB")
plt.show()

In [None]:
# Trained model output (left column) vs. wet target (right column):
# waveform on the top row, dB spectrogram on the bottom row.
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))
comparison = [("Model output", output_audio.numpy()[0]),
              ("Target (wet)", target_audio.numpy()[0])]
for col, (label, waveform) in enumerate(comparison):
    ax[0, col].plot(waveform)
    ax[0, col].set(title=label + ": waveform", xlabel="sample", ylabel="amplitude")
    D = librosa.amplitude_to_db(np.abs(librosa.stft(waveform)), ref=np.max)
    img = librosa.display.specshow(D, y_axis='linear', x_axis='time',
                                   sr=sample_rate, ax=ax[1, col])
    ax[1, col].set(title=label + ": spectrogram")
    fig.colorbar(img, ax=ax[1, col], format="%+2.0f dB")
plt.show()

In [None]:
# Play the trained model's output; row 0 of the (1, samples) tensor is the mono waveform.
IPython.display.Audio(output_audio.numpy()[0], rate=sample_rate, autoplay=True)

In [None]:
# Play the model output captured before fitting (audio_pre is already the first batch row as a NumPy array).
IPython.display.Audio(audio_pre, rate=sample_rate)

In [None]:
# Play the wet target clip; row 0 of the (1, samples) tensor is the mono waveform.
IPython.display.Audio(target_audio.numpy()[0], rate=sample_rate)

In [None]:
# Dump the model output and the target waveform (first batch row) to CSV, one sample per line.
for csv_name, audio_tensor in (("output_audio.csv", output_audio),
                               ("target_audio.csv", target_audio)):
    np.savetxt(csv_name, audio_tensor.numpy()[0], delimiter=",")

In [None]:
# Re-run only the parameter estimator on the dry clip to inspect its current predictions.
# NOTE: this rebinds `parameters`, which in the model-building cell named the symbolic
# encoder output fed to the VST layer.
parameters = parameter_model(input_audio)

In [None]:
# Print the predicted parameters grouped by role: indices 0-3 are b_0..b_3, 4-7 are
# c_0..c_3, and the remainder come in groups of five per filter (filters numbered from 1).
# Integers are compared with ==; the previous `j is 0` identity checks only worked by
# accident of CPython's small-int caching and raise SyntaxWarning on Python 3.8+.
params = parameters.numpy()[0]
filter_param_names = ("c_hp", "c_bp", "c_lp", "g", "R")
for i in range(num_params):
    value = params[i]
    if i < 4:
        print("b_%i = %f" % (i, value))
    elif i < 8:
        print("c_%i = %f" % (i - 4, value))
    else:
        filter_index, j = divmod(i - 8, 5)
        if j == 0:
            print("\nFilter %i:" % (filter_index + 1))
        print("%s = %f" % (filter_param_names[j], value))

In [None]:
# Change in each predicted parameter between the pre-fit snapshot (old_params) and now.
# The encoder ends in a sigmoid, so each parameter is in [0, 1] and every difference is in
# [-1, 1], hence the fixed y-range.
param_diff = params - old_params
print(param_diff)
fig, ax = plt.subplots()
ax.stem(param_diff)
ax.set_ylim([-1, 1])
ax.set(title="Parameter change after training", xlabel="parameter index", ylabel="new - old")
plt.show()