In [None]:
# TO USE: 
#    1. Upload your input and output wav files to the current directory in Colab
#    2. Edit the USER INPUTS section to point to your wav files, and choose a
#         model name, training mode, and number of epochs for training. If you
#         experience crashing due to low RAM, reduce the "input_size" parameter.
#    3. Run each section of code. The trained models and output wav files will be 
#         added to the "models" directory.
#
#     Note: Tested on CPU and GPU runtimes.

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import LSTM, Conv1D, Dense, TimeDistributed
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.backend import clear_session
from tensorflow.keras.activations import tanh, elu, relu
from tensorflow.keras.models import load_model
import tensorflow.keras.backend as K
from tensorflow.keras.utils import Sequence

import os
from scipy import signal
from scipy.io import wavfile
import numpy as np
import matplotlib.pyplot as plt
import math
import h5py

In [None]:
# EDIT THIS SECTION FOR USER INPUTS
#
name = 'test'                                # model name; outputs go to models/<name>/
in_file = 'data/ts9_test1_in_FP32.wav'       # dry (input) recording
out_file = 'data/ts9_test1_out_FP32.wav'     # processed (target) recording
epochs = 1                                   # number of training epochs

train_mode = 0     # 0 = speed training, 
                   # 1 = accuracy training 
                   # 2 = extended training

input_size = 150   # samples per training window

# Fail fast on a wrong path, before the model directory is created; otherwise
# a typo would leave an empty directory that blocks the next run.
for wav_path in (in_file, out_file):
    if not os.path.isfile(wav_path):
        raise FileNotFoundError("Could not find wav file: " + wav_path)

# Refuse to overwrite an existing model. The original bare `exit` was only an
# expression (it never called anything), so execution silently continued.
model_dir = 'models/'+name
if os.path.exists(model_dir):
    raise FileExistsError("A model with the same name already exists. Please choose a new name.")
os.makedirs(model_dir)

In [None]:
class WindowArray(Sequence):
    """Keras ``Sequence`` that serves batches of overlapping sliding windows.

    Window ``i`` is ``x[i : i + window_len]`` paired with ``y[i : i + window_len]``;
    consecutive windows advance by one element. Windows are built on demand
    for each batch rather than being materialised up front.

    Args:
        x: Input signal, sliced along its first axis.
        y: Target signal, sliced the same way; expected to be at least as
            long as ``x``.
        window_len: Number of elements per window.
        batch_size: Number of windows per batch. The final batch may be
            smaller when the window count is not a multiple of this value.
    """

    def __init__(self, x, y, window_len, batch_size=32):
        # Newer Keras versions expect the base-class initialiser to run.
        super().__init__()
        self.x = x
        self.y = y
        self.window_len = window_len
        self.batch_size = batch_size

    def _num_windows(self):
        """Return how many full windows fit in ``x`` (never negative)."""
        return max(0, len(self.x) - self.window_len + 1)

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        # Ceiling division: flooring here would silently drop the last
        # (num_windows % batch_size) windows from training and prediction.
        return -(-self._num_windows() // self.batch_size)

    def __getitem__(self, index):
        """Return the ``(x_batch, y_batch)`` pair for batch number ``index``.

        Raises:
            IndexError: If ``index`` does not refer to an existing batch.
        """
        first = index * self.batch_size
        # Clip the final batch so no window runs past the end of the data.
        last = min(first + self.batch_size, self._num_windows())
        if index < 0 or first >= last:
            raise IndexError("batch index out of range: %r" % (index,))
        starts = range(first, last)
        x_out = np.stack([self.x[s: s + self.window_len] for s in starts])
        y_out = np.stack([self.y[s: s + self.window_len] for s in starts])
        return x_out, y_out

In [None]:
def pre_emphasis_filter(x, coeff=0.95):
    """Apply a first-order pre-emphasis filter along axis 1 of a rank-3 tensor.

    The first step along axis 1 is passed through unchanged; every later step
    ``t`` becomes ``x[t] - coeff * x[t-1]``.
    """
    first_step = x[:, 0:1, :]
    current_steps = x[:, 1:, :]
    previous_steps = x[:, :-1, :]
    emphasized = current_steps - coeff*previous_steps
    return tf.concat([first_step, emphasized], axis=1)
    
def error_to_signal(y_true, y_pred): 
    """
    Error-to-signal ratio computed on pre-emphasised signals.

    Both tensors are passed through ``pre_emphasis_filter``; the result is the
    squared error summed over axis 1, divided by the squared target summed
    over axis 1 (plus 1e-10 so the denominator is never exactly zero).
    """
    target = pre_emphasis_filter(y_true)
    estimate = pre_emphasis_filter(y_pred)
    error_energy = K.sum(tf.pow(target - estimate, 2), axis=1)
    signal_energy = K.sum(tf.pow(target, 2), axis=1)
    return error_energy / (signal_energy + 1e-10)
    
def save_wav(name, data, rate=44100):
    """Write ``data`` as a mono 32-bit float wav file.

    Args:
        name: Destination path (or writable binary file object) passed to
            ``scipy.io.wavfile.write``.
        data: Array-like of samples; it is flattened to one dimension and
            cast to float32 before writing.
        rate: Sample rate in Hz stored in the wav header. Defaults to 44100,
            the value that was previously hard-coded.
    """
    samples = np.asarray(data).flatten().astype(np.float32)
    wavfile.write(name, rate, samples)

def normalize(data):
    """Scale ``data`` so that its largest absolute value becomes 1.

    Args:
        data: Array-like of samples.

    Returns:
        ``data`` divided by its peak absolute value. Empty input is returned
        as an (empty) array, and all-zero input is returned as zeros instead
        of dividing by zero and producing NaN.
    """
    samples = np.asarray(data)
    if samples.size == 0:
        return samples
    # Vectorised peak search; the Python built-ins max()/min() walk a NumPy
    # array element by element, which is very slow for long recordings.
    peak = np.max(np.abs(samples))
    # Silent input has no peak to scale by; divide by 1 to keep it unchanged.
    scale = peak if peak > 0 else 1.0
    return samples / scale


'''This is a similar Tensorflow/Keras implementation of the LSTM model from the paper:
    "Real-Time Guitar Amplifier Emulation with Deep Learning"
    https://www.mdpi.com/2076-3417/10/3/766/htm

    Uses a stack of two 1-D Convolutional layers, followed by LSTM, followed by 
    a Dense (fully connected) layer. Three preset training modes are available, 
    with further customization by editing the code. A Sequential tf.keras model 
    is implemented here.

    Note: RAM may be a limiting factor for the parameter "input_size". The wav data
      is preprocessed and stored in RAM, which improves training speed but quickly runs out
      if using a large number for "input_size".  Reduce this if you are experiencing
      RAM issues. 
    
    train_mode = 0   Speed training (default)
    train_mode = 1   Accuracy training
    train_mode = 2   Extended training (set "epochs" as desired, for example 50+)
'''

# Settings shared by every training mode.
batch_size = 4096
test_size = 0.2

# One preset per training mode, in the order
# (learning_rate, conv1d_strides, conv1d_filters, hidden_units).
speed_preset = (3e-4, 12, 16, 36)       # Speed Training
accuracy_preset = (3e-4, 4, 36, 64)     # Accuracy Training (~10x longer than Speed Training)
extended_preset = (3e-4, 3, 36, 96)     # Extended Training (~60x longer than Accuracy Training)

# Any train_mode other than 0 or 1 selects the extended preset.
selected_preset = (
    speed_preset if train_mode == 0
    else accuracy_preset if train_mode == 1
    else extended_preset
)
learning_rate, conv1d_strides, conv1d_filters, hidden_units = selected_preset


# Create Sequential Model ###########################################
# Two 'same'-padded Conv1D layers feed an LSTM that returns the full sequence;
# a time-distributed Dense(1) then produces one output value per input step.
clear_session()
model = Sequential([
    Conv1D(conv1d_filters, 12, activation=None, padding='same', input_shape=(input_size, 1)),
    Conv1D(conv1d_filters, 12, activation=None, padding='same'),
    LSTM(hidden_units, return_sequences=True),
    TimeDistributed(Dense(1, activation=None)),
])

# Train on the error-to-signal ratio and also report it alongside MAE / MSE.
optimizer = Adam(learning_rate=learning_rate)
tracked_metrics = ['mae', 'mse', error_to_signal]
model.compile(optimizer=optimizer, loss=error_to_signal, metrics=tracked_metrics)
model.summary()

# Load and Preprocess Data ###########################################
in_rate, in_data = wavfile.read(in_file)
out_rate, out_data = wavfile.read(out_file)

# The two recordings are paired sample by sample, so differing sample rates
# mean the pairs do not line up in time.
if in_rate != out_rate:
    print("Warning: sample rates differ (input %d Hz, output %d Hz); "
          "training pairs will be misaligned." % (in_rate, out_rate))

# Flatten to 1-D float32, peak-normalise, then reshape to (num_samples, 1).
X_all = in_data.astype(np.float32).flatten()
X_all = normalize(X_all).reshape(len(X_all), 1)
y_all = out_data.astype(np.float32).flatten()
y_all = normalize(y_all).reshape(len(y_all), 1)

# Windows are cut at the same offsets from both signals, so both must have
# the same number of samples; trim the longer one if they differ.
if len(X_all) != len(y_all):
    common_len = min(len(X_all), len(y_all))
    print("Warning: input has %d samples but output has %d; truncating both to %d."
          % (len(X_all), len(y_all), common_len))
    X_all, y_all = X_all[:common_len], y_all[:common_len]

# Hold out the final `test_size` fraction of the recording for validation.
train_examples = int(len(X_all) * (1 - test_size))
train_arr = WindowArray(X_all[:train_examples], y_all[:train_examples], input_size, batch_size=batch_size)
val_arr = WindowArray(X_all[train_examples:], y_all[train_examples:], input_size, batch_size=batch_size)

# Train Model ###################################################
# Use the `epochs` value from the USER INPUTS section; it was previously
# ignored because the call hard-coded epochs=1.
history = model.fit(train_arr, validation_data=val_arr, epochs=epochs, shuffle=True)
model.save('models/'+name+'/'+name+'.h5')

# Run Prediction #################################################
print("Running prediction..")

# Get the held-out tail of the wav data (the last `test_size` fraction) to
# run prediction and save results
y_the_rest, y_last_part = np.split(y_all, [int(len(y_all) * (1 - test_size))])
x_the_rest, x_last_part = np.split(X_all, [int(len(X_all) * (1 - test_size))])
test_arr = WindowArray(x_last_part, y_last_part, input_size, batch_size=batch_size)

prediction = model.predict(test_arr)

# The last frame of window i corresponds to target sample (input_size - 1 + i),
# so the targets start at that offset. The batch generator may yield fewer
# windows than there are target samples, so trim the targets to the number of
# predicted frames to keep y_test.wav and y_pred.wav the same length.
y_test = y_last_part[input_size-1:][:len(prediction)]

# The full prediction has a lot of redundant frames, 
# so we only pick the last frame from each sample in the window
save_wav('models/'+name+'/y_pred.wav', prediction[:, -1, :])
save_wav('models/'+name+'/x_test.wav', x_last_part)
save_wav('models/'+name+'/y_test.wav', y_test)

# Add additional data to the saved model (like input_size)
filename = 'models/'+name+'/'+name+'.h5'
# The context manager closes the file even if a write fails. The require_*
# calls reuse an existing group/dataset, so re-running this step does not
# raise because "info" was already created.
with h5py.File(filename, 'a') as f:
    grp = f.require_group("info")
    dset = grp.require_dataset("input_size", shape=(1,), dtype='int16')
    dset[0] = input_size