Simple recurrent neural network (RNN) implementation in Keras using LSTM (long short-term memory) units to identify the time of occurrence of some events in temporal data, based on the wavelet spectrogram of the data.

In [None]:
%matplotlib inline

In [None]:
import os
import numpy as np
import scipy.signal as ss
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import tensorflow
from tensorflow import keras
import h5py # TODO: use tensorflow.keras.utils.HDF5Matrix

In [None]:
# Print the local compute devices reported by TensorFlow.
# NOTE(review): `tensorflow.python.client` is not part of the documented
# public API; presumably a `tensorflow.config` call could replace it - confirm.
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

In [None]:
# Seed NumPy's global RNG so that the np.random.permutation call used to
# shuffle the datasets is reproducible. This seeds NumPy only.
np.random.seed(1234)

# Load training/validation data

In [None]:
# load training/validation data with labels
# (the context manager closes the HDF5 file even if a dataset read fails)
with h5py.File(os.path.join('data', 'processed', 'data.h5'), 'r') as f:
    X0 = f['X0'][:]                      # raw data traces
    X = f['X'][:]                        # model input (plotted as wavelet spectrogram)
    Y = f['Y'][:]                        # label traces
    labels = f['labels'][:]
    waveletfreqs = f['waveletfreqs'][:]  # frequency axis of the spectrograms

# center raw data: subtract the mean taken along the last axis
X0 = (X0.T - X0.mean(axis=-1)).T

# randomly permute order of datasets; the same permutation is applied to
# every array so that corresponding entries stay aligned
permutation = np.random.permutation(np.arange(X0.shape[0]))
X0 = X0[permutation]
X = X[permutation]
Y = Y[permutation]
labels = labels[permutation]

In [None]:
# global parameters used by the plotting and training cells
Fs = 2500  # sampling frequency in Hz
n_val_samples = 50  # size of the validation split (taken from the end of the arrays)
n_samples = len(X0)  # total number of datasets
time = np.arange(X0.shape[1]) / Fs  # time axis in seconds

In [None]:
# sanity check: first raw trace drawn together with its label trace
fig, ax = plt.subplots()
ax.plot(time, X0[0], label='raw data')
ax.plot(time, Y[0, :, 0], label='label (y)')
ax.legend()
ax.set_xlabel('t (s)')

In [None]:
# overview: label matrix (top) and raw data matrix (bottom), one row per dataset
fig, axes = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(12, 12))
dataset_index = np.arange(n_samples)
panels = [(Y[:, :, 0], 'labels (y)'), (X0, 'raw data')]
for ax, (matrix, title) in zip(axes, panels):
    ax.pcolormesh(time, dataset_index, matrix)
    ax.set_ylabel('#')
    ax.set_title(title)
axes[1].set_xlabel('t (s)')
# freeze the tight limits once both panels have been drawn
for ax in axes:
    ax.axis(ax.axis('tight'))

In [None]:
# plot wavelet spectrograms vs. labels and raw data for some samples

# upper color limit shared by all figures; X.std() runs over the whole array
# and does not change between iterations, so it is computed once up front
spectrogram_vmax = X.std() * 2

for i in range(3):
    gs = GridSpec(4, 1)
    fig = plt.figure(figsize=(12, 8))

    # top row: raw trace and label trace
    ax0 = fig.add_subplot(gs[0, 0])
    ax0.plot(time, X0[i], label='raw data')
    ax0.plot(time, Y[i, :, 0], label='label (y)')
    ax0.legend(ncol=2)
    ax0.axis(ax0.axis('tight'))
    ax0.set_title('label and raw data')
    # the x axis is shared with the spectrogram below, so hide these tick labels
    plt.setp(ax0.get_xticklabels(), visible=False)

    # remaining three rows: spectrogram on the same time axis
    ax1 = fig.add_subplot(gs[1:, 0], sharex=ax0)
    im = ax1.pcolormesh(time, waveletfreqs, X[i].T, vmin=0, vmax=spectrogram_vmax)
    ax1.axis(ax1.axis('tight'))
    ax1.set_ylabel('f (Hz)')
    ax1.set_xlabel('t (s)')

# Set up recurrent neural network

In [None]:
def generate_model(input_shape, lr=0.01, dropout_rate=0.2, layer_sizes=(5, 5, 5)):
    """Build and compile the Conv1D + 2x LSTM sequence-labelling network.

    The network maps an input sequence to one sigmoid output per time step:
    Conv1D -> BatchNorm -> ReLU -> Dropout -> LSTM -> BatchNorm -> Dropout
    -> LSTM -> Dropout -> BatchNorm -> Dropout -> TimeDistributed(Dense(1)).

    Note: calling this function resets the global Keras state via
    ``keras.backend.clear_session()``.

    Parameters
    ----------
    input_shape : tuple
        Shape passed to ``keras.layers.Input`` (without the batch axis).
    lr : float
        Learning rate handed to the Adam optimizer.
    dropout_rate : float
        Rate used by every Dropout layer.
    layer_sizes : sequence of 3 ints
        Number of Conv1D filters, units of LSTM layer 1 and units of
        LSTM layer 2, in that order.

    Returns
    -------
    keras.models.Model
        Model compiled with binary cross-entropy loss and the metrics
        ``'accuracy'`` and ``'mse'``.
    """
    keras.backend.clear_session()

    # input layer
    inputs = keras.layers.Input(shape=input_shape)

    # conv layer ('same' padding keeps the sequence length unchanged)
    x = keras.layers.Conv1D(layer_sizes[0],
                            kernel_size=5, strides=1,
                            padding='same'
                           )(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation('relu')(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    # LSTM layer 1 (return_sequences=True: one output per time step)
    x = keras.layers.LSTM(layer_sizes[1], return_sequences=True)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    # LSTM layer 2
    # NOTE(review): unlike layer 1, dropout is applied both before and after
    # batch normalization here - kept as-is to preserve the architecture;
    # confirm whether the double dropout is intended.
    x = keras.layers.LSTM(layer_sizes[2], return_sequences=True)(x)
    x = keras.layers.Dropout(dropout_rate)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    # dense output layer, applied independently at every time step
    predictions = keras.layers.TimeDistributed(
        keras.layers.Dense(1, activation='sigmoid'))(x)

    # Define model
    model = keras.models.Model(inputs=inputs, outputs=predictions)

    # `learning_rate` is the supported keyword; the legacy `lr` alias is
    # deprecated and rejected by recent Keras releases
    opt = keras.optimizers.Adam(learning_rate=lr)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy', 'mse'])

    return model

In [None]:
# Build the network: first input dimension left as None, second taken from
# the size of the last axis of X.
model = generate_model(input_shape=(None, X.shape[2]))

In [None]:
# Print an overview of the model that was just built.
model.summary()

In [None]:
# train on everything except the last n_val_samples datasets, which are
# held out as the validation split
X_train, Y_train = X[:-n_val_samples], Y[:-n_val_samples]
validation_split = (X[-n_val_samples:], Y[-n_val_samples:])
history = model.fit(X_train, Y_train,
                    batch_size=20, epochs=20,
                    validation_data=validation_split)

In [None]:
# learning curves on a logarithmic y axis; the accuracy curves are not drawn
plt.figure()
for curve in ('loss', 'mse', 'val_mse'):
    plt.semilogy(history.history[curve], '-o', label=curve)
plt.legend()
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('training/validation loss')

In [None]:
# visualize predictions on some samples from the validation set
n_plots = 3
X_val = X[-n_val_samples:]
Y_val = Y[-n_val_samples:]

Y_pred = model.predict(X_val)

# upper color limit shared by all spectrogram panels; X.std() does not change
# inside the loop, so it is computed once
spectrogram_vmax = X.std() * 2

# compare prediction to ground truth
# squeeze=False keeps `axes` two-dimensional, so the [row, column] indexing
# below also works when n_plots == 1
fig, axes = plt.subplots(n_plots, 2, figsize=(12, 12),
                         sharex=True, sharey='col', squeeze=False)
for i in range(n_plots):
    axes[i, 0].pcolormesh(time, waveletfreqs, X_val[i].T, vmin=0, vmax=spectrogram_vmax)
    # raw strings: '\h' in a normal string literal is an invalid escape sequence
    axes[i, 1].plot(time, Y_val[i], label=r'$y(t)$')
    axes[i, 1].plot(time, Y_pred[i], label=r'$\hat{y}(t)$')
    if i == 0:
        axes[i, 1].legend()
        axes[i, 0].set_title(r'$X(t)$')
        axes[i, 1].set_title(r'$y(t)$ vs $\hat{y}(t)$')
    axes[i, 0].set_ylabel('f (Hz)')
    axes[i, 1].set_ylabel('probability')

# time-axis labels on the bottom row only (x axes are shared)
axes[-1, 0].set_xlabel('$t$ (s)')
axes[-1, 1].set_xlabel('$t$ (s)')