# RippleNet_training_bidirectional
Training of a simple bidirectional recurrent neural network (RNN) implementation in `tensorflow.keras` using LSTM (long short-term memory) units to identify the time of occurrence of sharp wave ripple (SPW-R) events in temporal data.

Author: Espen Hagen (<https://github.com/espenhgn>)

LICENSE: <https://github.com/CINPLA/RippleNet/blob/master/LICENSE>

In [None]:
# allow running on Google Colab for training using Google Drive for file access
try:
    from google.colab import drive
    drive.mount('/content/gdrive')
    %cd gdrive/My\ Drive/Colab\ Notebooks/RippleNet
    %tensorflow_version 2.x
except:
    pass

In [None]:
# IPython magic: use matplotlib's inline backend so figures are rendered in the
# notebook output
%matplotlib inline

In [None]:
import os
import numpy as np
import scipy.signal as ss
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib import colors
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import plot_model
import ripplenet.models
import h5py
import pickle
import random

In [None]:
# diagnostic output only: print the compute devices reported by TensorFlow
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

In [None]:
# diagnostic output only: TensorFlow version, GPU device name as reported by
# tf.test.gpu_device_name(), and the number of physical GPU devices found
print(tf.__version__)
print(tf.test.gpu_device_name())
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

In [None]:
# set random seeds with some additional environment variables to ensure deterministic output
# The same seed value is used for Python's `random`, NumPy and TensorFlow.
random_seed = 789
os.environ['TF_DETERMINISTIC_OPS'] = '1'
# NOTE(review): PYTHONHASHSEED is presumably only read at interpreter start-up,
# so setting it here likely affects child processes only, not the hashing in
# this running kernel -- confirm, and export it before launching if needed.
os.environ['PYTHONHASHSEED']=str(random_seed)
random.seed(random_seed)
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

In [None]:
# select dataset (may have generated different sets.)
# The index is formatted as a zero-padded two-digit number into the data file
# names in the loading cells below (e.g. 'train_{:02}.h5' -> 'train_00.h5').
dataset_index = 0 

# Load training/validation data

In [None]:
# select species for training/validation data (mouse, rat or both)
mouse = True
rat = True

# Fail early with a clear message: every data-loading cell below is guarded by
# these flags, so with both set to False no training/validation arrays would be
# defined and later cells would stop with an unrelated NameError.
if not (mouse or rat):
    raise ValueError('at least one of `mouse` and `rat` must be True')

In [None]:
# output destination
output_folder = 'trained_networks'
# create the folder if needed; exist_ok=True makes this a no-op when it already
# exists and avoids the race between a separate existence check and creation
os.makedirs(output_folder, exist_ok=True)

# prefix for trained network files (training loss/MSE, weights, 'best' weights)
rnn_prefix = 'ripplenet_bidirectional'

In [None]:
if mouse:
    # training and validation files
    f_name_train = 'train_{:02}.h5'
    f_name_val = 'validation_{:02}.h5'

    # training data; the context manager closes the file even if a read fails.
    # A trailing axis is appended to X0 to form the network input array.
    with h5py.File(os.path.join('data', f_name_train.format(dataset_index)),
                   'r') as f:
        X_train = np.expand_dims(f['X0'][:], -1)
        Y_train = f['Y'][:]

    # The validation set and the arrays used for plotting live in the same
    # file, so it is opened once and everything is read in one go.
    with h5py.File(os.path.join('data', f_name_val.format(dataset_index)),
                   'r') as f:
        # validation data (read separately so X_val/Y_val do not share memory
        # with the plotting arrays below)
        X_val = np.expand_dims(f['X0'][:], -1)
        Y_val = f['Y'][:]

        # load some data for plotting
        X0 = f['X0'][:]
        X1 = f['X1'][:]
        S = f['S'][:]
        Y = f['Y'][:]
        S_freqs = f['S_freqs'][:]

In [None]:
# Add rat training/validation data to sets
# Runs only when both species are selected: the rat arrays are appended along
# the first axis to the mouse arrays loaded by the previous cell.
if rat and mouse:
    # rat
    f_name_train = 'train_tingley_{:02}.h5'
    f_name_val = 'validation_tingley_{:02}.h5'

    # training data; the context manager closes the file even if a read fails
    with h5py.File(os.path.join('data', f_name_train.format(dataset_index)),
                   'r') as f:
        X_train = np.concatenate((X_train, np.expand_dims(f['X0'][:], -1)))
        Y_train = np.concatenate((Y_train, f['Y'][:]))

    # The validation set and the arrays used for plotting live in the same
    # file, so it is opened once and everything is read in one go.
    with h5py.File(os.path.join('data', f_name_val.format(dataset_index)),
                   'r') as f:
        # validation data
        X_val = np.concatenate((X_val, np.expand_dims(f['X0'][:], -1)))
        Y_val = np.concatenate((Y_val, f['Y'][:]))

        # load some data for plotting (S_freqs is not re-read here; the value
        # loaded together with the mouse data is kept)
        X0 = np.concatenate((X0, f['X0'][:]))
        X1 = np.concatenate((X1, f['X1'][:]))
        S = np.concatenate((S, f['S'][:]))
        Y = np.concatenate((Y, f['Y'][:]))

In [None]:
if rat and not mouse:
    # rat
    f_name_train = 'train_tingley_{:02}.h5'
    f_name_val = 'validation_tingley_{:02}.h5'

    # training data; the context manager closes the file even if a read fails.
    # A trailing axis is appended to X0 to form the network input array.
    with h5py.File(os.path.join('data', f_name_train.format(dataset_index)),
                   'r') as f:
        X_train = np.expand_dims(f['X0'][:], -1)
        Y_train = f['Y'][:]

    # The validation set and the arrays used for plotting live in the same
    # file, so it is opened once and everything is read in one go.
    with h5py.File(os.path.join('data', f_name_val.format(dataset_index)),
                   'r') as f:
        # validation data (read separately so X_val/Y_val do not share memory
        # with the plotting arrays below)
        X_val = np.expand_dims(f['X0'][:], -1)
        Y_val = f['Y'][:]

        # load some data for plotting
        X0 = f['X0'][:]
        X1 = f['X1'][:]
        S = f['S'][:]
        Y = f['Y'][:]
        S_freqs = f['S_freqs'][:]

In [None]:
# needed parameters
Fs = 1250  # Hz, sampling freq
# time axis in seconds: one entry per column of X0
time = np.arange(X0.shape[1]) / Fs

# center raw data: subtract from every row its own mean. keepdims=True keeps
# the reduced axis with length 1 so the means broadcast against X0 directly,
# without transposing back and forth.
X0 = X0 - X0.mean(axis=-1, keepdims=True)

# total number of samples
n_samples = X0.shape[0]

In [None]:
# overview figure: label matrix (left panel) next to the centered raw data
# matrix (right panel), one row per sample, sharing both axes
fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True,
                         figsize=(12, 12))
ax_labels, ax_raw = axes
rows = np.arange(n_samples)

# labels: first channel of Y
ax_labels.pcolormesh(time, rows, Y[:, :, 0])
ax_labels.set(ylabel='#', title='labels (y)')

# raw data, color range clipped symmetrically at three standard deviations
clim = X0.std() * 3
ax_raw.pcolormesh(time, rows, X0, vmin=-clim, vmax=clim)
ax_raw.set(ylabel='#', xlabel='t (s)', title='raw data (X)')

# fit the axis limits tightly around the plotted data
for ax in axes:
    tight_limits = ax.axis('tight')
    ax.axis(tight_limits)

In [None]:
# plot wavelet spectrograms vs. labels and raw data for some samples

# The color limits are the 1st and 99th percentiles of log(S) taken over the
# whole array. They do not depend on the sample, so they are computed once
# here instead of once per figure.
vmin, vmax = np.exp(np.percentile(np.log(S), [1, 99]))

# show at most five samples (fewer if the set is smaller)
for i in range(min(5, len(X0))):
    gs = GridSpec(2, 1)
    fig = plt.figure(figsize=(12, 6))

    # top panel: raw trace, second signal (X1) and label for sample i
    ax0 = fig.add_subplot(gs[0, 0])
    ax0.plot(time, X0[i, ], label='$X(t)$')
    ax0.plot(time, X1[i, ], label=r'$\phi_\mathrm{bp}(t)$')
    ax0.plot(time, Y[i, :, 0], label='label ($y$)')
    ax0.legend(ncol=2)
    ax0.axis(ax0.axis('tight'))
    ax0.set_title('label, raw data and spectrograms')
    # the x axis is shared with the panel below; hide the duplicate tick labels
    plt.setp(ax0.get_xticklabels(), visible=False)

    # bottom panel: spectrogram of sample i on a logarithmic color scale
    ax1 = fig.add_subplot(gs[1:, 0], sharex=ax0)
    im = ax1.pcolormesh(time, S_freqs, S[i, ].T,
                        norm=colors.LogNorm(vmin=vmin, vmax=vmax),
                        cmap='inferno')
    ax1.axis(ax1.axis('tight'))
    ax1.set_ylabel('$f$ (Hz)')
    ax1.set_xlabel('$t$ (s)')

# Set up recurrent neural network

In [None]:
# Build the bidirectional LSTM network via the project's model factory.
# input_shape: None for the first entry (presumably a variable-length time
# axis -- confirm in ripplenet.models) and the size of the last axis of X_train.
# seed: offset by one from the global random seed.
model = ripplenet.models.get_bidirectional_LSTM_model(input_shape=(None, X_train.shape[2]), 
                                                      layer_sizes=[20, 10, 6, 6],
                                                      seed=random_seed+1)

In [None]:
# print an overview of the model that was just built
model.summary()

In [None]:
# plot_model(model, show_shapes=True, expand_nested=True)

In [None]:
# callbacks used during training:
#  - a checkpoint that is (over)written only when the monitored validation
#    MSE reaches a new minimum
#  - a CSV log of the training history
filepath = os.path.join(
    output_folder,
    '{}_best_random_seed{}.h5'.format(rnn_prefix, random_seed),
)
checkpoint_best = keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_mse',
    mode='min',
    save_best_only=True,
    verbose=1,
)

csv_log_path = os.path.join(
    output_folder,
    '{}_history_random_seed{}.csv'.format(rnn_prefix, random_seed),
)
callback_hist = keras.callbacks.CSVLogger(csv_log_path)

callbacks_list = [checkpoint_best, callback_hist]

In [None]:
# fit the network on the training set, evaluating on the validation set and
# running the checkpoint/logging callbacks; the returned object holds the
# training history used by the cells below
history = model.fit(
    X_train, Y_train,
    validation_data=(X_val, Y_val),
    epochs=50,
    batch_size=20,
    callbacks=callbacks_list,
)

In [None]:
# store the training history dict as a pickle next to the trained networks so
# it can be reloaded later
history_pkl_path = os.path.join(
    output_folder,
    '{}_history_random_seed{}.pkl'.format(rnn_prefix, random_seed),
)
with open(history_pkl_path, 'wb') as f:
    pickle.dump(history.history, f)

In [None]:
# training and validation curves on a logarithmic y axis, one line per
# recorded history entry
plt.figure(figsize=(12, 12))
for metric in ('loss', 'val_loss', 'mse', 'val_mse'):
    plt.semilogy(history.history[metric], '-o', label=metric)
plt.title('training/validation MSE')
plt.xlabel('epochs')
plt.ylabel('MSE')
plt.legend()

In [None]:
# Save the trained model
# This is the model state after the last epoch, written to the output folder;
# the checkpoint callback writes to a separate '..._best_...' file.
model.save(os.path.join(output_folder, '{}_random_seed{}.h5'.format(rnn_prefix, random_seed)))