### tsdGAN: A generative adversarial network approach for removing electrocardiographic interference from electromyographic signals 
Lucas Haberkamp<sup>1,2</sup>, Charles A. Weisenbach<sup>1</sup>, Peter Le<sup>3</sup>  
<sup>1</sup>Naval Medical Research Unit Dayton, Wright-Patterson Air Force Base, OH, USA   
<sup>2</sup>Leidos, Reston, VA, USA   
<sup>3</sup>Air Force Research Laboratory, 711th Human Performance Wing, Wright-Patterson Air Force Base, OH, USA

#### This notebook is used to train the tsdGAN deep learning model on experimental data

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K
from sklearn import utils
from tensorflow.keras.activations import gelu
from sklearn.preprocessing import StandardScaler
import pandas as pd
from scipy import signal
import time

Define a zero-phase Butterworth filter function (forward-backward filtering with `filtfilt`)

In [None]:
def butterfilter(x, Fc, Fs, ftype, order=2):
    """Apply a zero-phase Butterworth filter to ``x``.

    Parameters
    ----------
    x : array-like
        Signal to filter (filtered along its last axis, the ``filtfilt`` default).
    Fc : float or sequence of two floats
        Cutoff frequency in Hz; pass ``[low, high]`` for 'bandpass'/'bandstop'.
    Fs : float
        Sampling frequency in Hz.
    ftype : str
        Filter type forwarded to ``scipy.signal.butter`` as ``btype``
        (e.g. 'low', 'high', 'bandpass', 'bandstop').
    order : int, default 2
        Butterworth design order. Because ``filtfilt`` runs the filter forward
        and then backward, the overall magnitude response is the square of the
        designed filter's response (roll-off is doubled).

    Returns
    -------
    numpy.ndarray
        The filtered signal, same shape as ``x``.
    """
    # butter() without ``fs`` expects cutoffs normalised to Nyquist (Fs / 2).
    Wn = np.asarray(Fc, dtype=float) / (Fs / 2)
    b, a = signal.butter(order, Wn, btype=ftype)
    # Forward-backward filtering cancels the phase shift of the IIR filter.
    return signal.filtfilt(b, a, x)

Load the training and validation datasets

In [None]:
# Experimental training arrays: inputs (x), targets (y) and QRS labels.
# Each file holds a 3-D array (the transpose below requires rank 3).
x_train = np.load('../../Data/Preprocessed Data/Experimental/x_real.npy')
y_train = np.load('../../Data/Preprocessed Data/Experimental/y_real.npy')
label_train = np.load('../../Data/Preprocessed Data/Experimental/label_real.npy')

# Move the last axis to the front and fold it into the batch axis, so every
# (axis-0, axis-2) pair becomes its own single-feature sequence of length
# shape[1]. The right-hand side still sees the original (pre-reshape) array.
x_train = x_train.transpose(2, 0, 1).reshape(-1, x_train.shape[1], 1)
y_train = y_train.transpose(2, 0, 1).reshape(-1, y_train.shape[1], 1)
label_train = label_train.transpose(2, 0, 1).reshape(-1, label_train.shape[1], 1)

# Held-out experimental trial used for visual validation: the header is on
# row 13 and the first three columns are dropped.
val_df = pd.read_csv(
    '../../Data/Raw TS EMG Data/Validation/TrunkStability_DS_S20_EMG_Raw_17.csv',
    header=13,
).iloc[:, 3:]
# Remove each channel's mean, then low-pass every column at 500 Hz (Fs = 1920 Hz).
val_df = val_df.sub(val_df.mean())
val_df = val_df.apply(butterfilter, Fc=500, Fs=1920, ftype='low')
val_df.head()

Build the reconstruction target by masking the contaminated input with the QRS labels, so that the target does not contain the QRS-complex locations

In [None]:
# Reconstruction target: the contaminated input multiplied element-wise by its
# label mask. NOTE(review): assumes label_train is 0 at QRS samples and 1
# elsewhere, so QRS locations are excluded from the target — TODO confirm.
# A single vectorised multiply replaces the per-sample Python loop (and keeps
# the (N, T, 1) shape even when N == 0).
if x_train.shape != label_train.shape:
    raise ValueError(
        f"x_train {x_train.shape} and label_train {label_train.shape} must have the same shape"
    )
y_target = x_train * label_train

In [None]:
# Sanity check: report the shape of every training array before training.
print("x_shape:", np.shape(x_train))
print("y_shape:", np.shape(y_train))
print("y_target:", np.shape(y_target))
print("label_shape:", np.shape(label_train))

In [None]:
# Confirm the experimental training data is unpaired: overlay the first few
# contaminated inputs (x_train) with their y_train counterparts.
for i in range(5):
    plt.plot(x_train[i], label="x_train")
    plt.plot(y_train[i], label="y_train")
    plt.title("Unpaired Experimental Data: " + str(i))
    plt.xlabel("Samples")
    plt.ylabel("Normalized Amplitude")
    plt.legend()
    plt.show()

In [None]:
# Visual check of the masked targets (x_train multiplied by the QRS label mask)
# for the first five sequences.
for i in range(5):
    plt.plot(y_target[i])
    plt.title(f"Masked Target Data: {i}")
    plt.xlabel("Samples")
    plt.ylabel("Normalized Amplitude")
    plt.show()

In [None]:
# Overlay the QRS label trace on each of the first ten contaminated inputs.
for i in range(10):
    plt.plot(x_train[i])
    plt.plot(label_train[i])
    plt.title(f'QRS Complex Label: {i}')
    plt.xlabel("Samples")
    plt.ylabel("Normalized Amplitude")
    plt.show()

Define the TCN separation module

In [None]:
def tcn_block(inputs, filters, dilation_rates):
    """Residual stack of dilated, depthwise-separable 1-D convolutions.

    The input is layer-normalised once. Then, for each dilation rate, two
    'same'-padded SeparableConv1D layers (kernel 3, GELU) are applied with a
    LayerNormalization between them. Their output is added back to the running
    residual and normalised again.

    Parameters
    ----------
    inputs : Keras tensor
        Input sequence; its channel count must equal ``filters`` for the
        residual ``Add`` to be valid.
    filters : int
        Channels produced by every separable convolution.
    dilation_rates : iterable of int
        One residual unit is built per rate, in the given order.

    Returns
    -------
    Keras tensor
        Output with the same shape as ``inputs``.
    """
    def _dilated_sep_conv(tensor, rate):
        return SeparableConv1D(filters, dilation_rate=rate, kernel_size=3,
                               padding='same', activation=gelu)(tensor)

    skip = LayerNormalization()(inputs)
    for rate in dilation_rates:
        hidden = _dilated_sep_conv(skip, rate)
        hidden = LayerNormalization()(hidden)
        hidden = _dilated_sep_conv(hidden, rate)
        skip = LayerNormalization()(Add()([hidden, skip]))
    return skip

Define the generator model

In [None]:
def build_generator():
    """Build the tsdGAN generator: strided encoder -> TCN mask -> decoder.

    The encoder has a kernel-7 Conv1D followed by two stride-2 Conv1D stages,
    each followed by LayerNormalization, then a 1x1 projection to 128 channels.
    A TCN block (dilations 1..64) produces a sigmoid gate that multiplies the
    encoding. Two stride-2 Conv1DTranspose stages upsample again, and a final
    kernel-7 Conv1D projects back to one channel.

    NOTE(review): with 'same' padding the output length is 4 * ceil(L / 4),
    so it equals the input length L only when L is a multiple of 4.

    Returns
    -------
    tf.keras.Model
        Model named 'generator' with input shape (None, 1) and a single output
        channel.
    """
    dilation_rates = [1, 2, 4, 8, 16, 32, 64]
    inputs = Input(shape=(None, 1))

    # Encoder stages as (filters, kernel_size, strides), each Conv1D + LayerNorm.
    encoded = inputs
    for n_filters, width, stride in ((32, 7, 1), (64, 3, 2), (128, 3, 2)):
        encoded = Conv1D(filters=n_filters, strides=stride, kernel_size=width,
                         padding='same', activation=gelu)(encoded)
        encoded = LayerNormalization()(encoded)
    encoded = Conv1D(filters=128, kernel_size=1, padding='same', activation=gelu)(encoded)

    # Separation: TCN features -> sigmoid gate applied channel-wise to the encoding.
    gate = tcn_block(encoded, filters=128, dilation_rates=dilation_rates)
    gate = Conv1D(filters=128, kernel_size=1, padding='same', activation='sigmoid')(gate)
    gated = Multiply()([encoded, gate])
    decoded = LayerNormalization()(gated)

    # Decoder: two stride-2 transposed convolutions mirroring the strided encoder.
    for n_filters in (64, 32):
        decoded = Conv1DTranspose(filters=n_filters, strides=2, kernel_size=3,
                                  padding='same', activation=gelu)(decoded)
        decoded = LayerNormalization()(decoded)

    outputs = Conv1D(1, kernel_size=7, padding='same')(decoded)
    return Model(inputs, outputs, name='generator')

# Build a throwaway generator only to print its architecture.
build_generator().summary()

In [None]:
def build_discriminator():
    """Build the tsdGAN discriminator that scores a whole sequence.

    Gaussian noise (stddev 0.1) is added to the input first. Keras only applies
    this noise in training mode. The noisy input passes through a kernel-7
    Conv1D, a stride-2 Conv1D, a 1x1 projection to 64 channels, a TCN block
    (dilations 1..128) and another 1x1 projection. The result is averaged over
    time and mapped by a linear Dense(1) layer.

    Returns
    -------
    tf.keras.Model
        Model named 'discriminator' with input shape (None, 1) and one
        unbounded score per sequence.
    """
    dilation_rates = [1, 2, 4, 8, 16, 32, 64, 128]
    inputs = Input(shape=(None, 1))

    feats = GaussianNoise(0.1)(inputs)

    # Front-end stages as (filters, kernel_size, strides), each Conv1D + LayerNorm.
    for n_filters, width, stride in ((32, 7, 1), (64, 3, 2)):
        feats = Conv1D(filters=n_filters, strides=stride, kernel_size=width,
                       padding='same', activation=gelu)(feats)
        feats = LayerNormalization()(feats)
    feats = Conv1D(filters=64, kernel_size=1, padding='same', activation=gelu)(feats)

    feats = tcn_block(feats, filters=64, dilation_rates=dilation_rates)
    feats = Conv1D(filters=64, kernel_size=1, padding='same', activation=gelu)(feats)

    pooled = GlobalAveragePooling1D()(feats)
    return Model(inputs, Dense(1)(pooled), name='discriminator')

# Build a throwaway discriminator only to print its architecture.
build_discriminator().summary()

Function to create and compile the generator, the discriminator, and the combined GAN model

In [None]:
def create_gans():
    """Create the generator, the discriminator and the combined GAN model.

    The discriminator is compiled on its own (Adam, lr 2e-4, beta_1 0.5, MSE)
    while it is still trainable. It is then marked non-trainable before being
    wired into the combined model. The combined 'GAN' model takes
    ``[sequence, weights]`` and produces two outputs:
      * the discriminator score of the generated sequence (MSE, weight 1), and
      * the generated sequence multiplied by ``weights``, in a layer named
        'reconstruction' (MAE, weight 10).
    It is compiled with Adam (lr 1e-4, beta_1 0.5).

    Returns
    -------
    tuple
        ``(GAN, generator, discriminator)``.
    """
    discriminator = build_discriminator()
    generator = build_generator()

    # Compile while trainable so discriminator.train_on_batch updates its weights.
    discriminator.compile(optimizer=Adam(learning_rate=2e-4, beta_1=0.5), loss='mse')
    # Freeze it for the combined model, which should only update the generator.
    discriminator.trainable = False

    seq_in = Input(shape=(None, 1))
    weight_in = Input(shape=(None, 1))

    generated = generator(seq_in)
    score = discriminator(generated)
    weighted = Multiply(name='reconstruction')([generated, weight_in])

    gan = Model([seq_in, weight_in], [score, weighted], name='GAN')
    gan.compile(optimizer=Adam(learning_rate=1e-4, beta_1=0.5),
                loss=['mse', 'mae'],
                loss_weights=[1, 10])

    return gan, generator, discriminator

Create the models

In [None]:
# Build the generator, the separately compiled discriminator, and the combined
# GAN model used to train the generator.
GAN, generator, discriminator = create_gans()

Function to apply tsdGAN to a full-length experimental recording using overlapping windows (used to check performance on the validation trial during training)

In [None]:
# Module-level sklearn StandardScaler (zero-mean / unit-variance scaling).
# It is shared mutable state: every fit() overwrites its stored statistics.
scaler = StandardScaler()

def tsdGAN_filter(raw_data, window_size=2880, model=None):
    """Filter one continuous signal with tsdGAN using overlapping windows.

    The signal is cut into windows of ``window_size`` samples that advance by
    ``window_size // 3`` samples. One extra window is always anchored to the
    end of the signal so the tail is covered. Each window is standardised on
    its own statistics (zero mean, unit population standard deviation, as
    sklearn's StandardScaler does), passed through the model, and mapped back
    to that window's original scale.

    The predictions are stitched so each output sample comes from exactly one
    window:
      * the first window contributes its leading two thirds,
      * every interior window contributes its middle third,
      * the end-anchored window contributes everything after the last regular
        window's middle third.

    Parameters
    ----------
    raw_data : pandas.Series or array-like
        1-D signal to filter (e.g. one column of ``val_df``).
    window_size : int, default 2880
        Window length in samples; must be at least 3.
    model : object with a Keras-style ``predict(x, batch_size=...)``, optional
        Network used for inference. Defaults to the module-level ``generator``.

    Returns
    -------
    numpy.ndarray
        1-D filtered signal, the same length as ``raw_data`` provided the model
        returns sequences as long as its inputs.

    Raises
    ------
    ValueError
        If ``window_size < 3`` or the signal is shorter than ``window_size``.
    """
    # No global named ``model`` survives model construction (each summary cell
    # deletes it), so default to the trained generator.
    predictor = generator if model is None else model

    data = np.asarray(raw_data).ravel()
    n_samples = data.shape[0]
    step_size = window_size // 3
    if step_size < 1:
        raise ValueError(f"window_size must be at least 3, got {window_size}")
    if n_samples < window_size:
        raise ValueError(
            f"signal has {n_samples} samples; need at least window_size={window_size}"
        )

    # Regularly spaced window starts, plus one window anchored to the end.
    starts = list(range(0, n_samples - window_size, step_size))
    starts.append(n_samples - window_size)
    windows = np.stack([data[s:s + window_size] for s in starts]).astype(float)

    # Per-window standardisation; constant windows keep a unit scale so we
    # never divide by zero.
    means = windows.mean(axis=1, keepdims=True)
    stds = windows.std(axis=1, keepdims=True)
    stds[stds == 0] = 1.0

    preds = predictor.predict(((windows - means) / stds)[..., np.newaxis], batch_size=32)
    # Undo the standardisation window by window.
    preds = np.asarray(preds).reshape(len(starts), -1) * stds + means

    if len(starts) == 1:
        # The signal is exactly one window long.
        return preds[0]

    pieces = [preds[0][:2 * step_size]]
    pieces.extend(seg[step_size:2 * step_size] for seg in preds[1:-1])
    # The last regular window's middle third ends at starts[-2] + 2 * step_size;
    # take the rest of the signal from the end-anchored window.
    tail_offset = starts[-2] + 2 * step_size - starts[-1]
    pieces.append(preds[-1][tail_offset:])
    return np.concatenate(pieces)

In [None]:
# Reset global state generated by Keras (e.g. auto-generated layer-name counters).
# NOTE(review): this runs after create_gans() has already built GAN, generator
# and discriminator, which are still used by the training loop below — confirm
# this has no effect on those models in the installed TF/Keras version.
tf.keras.backend.clear_session()

Training Loop

In [None]:
# Core training loop. Each epoch shuffles the training arrays together. Then,
# for every full mini-batch, it:
#   1. trains the discriminator on y_train sequences with target 1,
#   2. trains the discriminator on generator outputs for x_train with target 0,
#   3. trains the generator through the combined GAN model (the discriminator
#      is set non-trainable inside GAN in create_gans). The adversarial output
#      is pushed towards 1, and the weighted MAE reconstruction term compares
#      generator output * labels against y_target.
# Every 10 epochs the models and loss histories are saved and tsdGAN output on
# the validation recording is plotted.
batch_size = 64

# Calculate number of batches; the trailing partial batch is skipped each epoch.
# NOTE(review): if x_train has fewer than batch_size rows this is 0, and the
# np.concatenate calls below raise on empty lists.
num_batches = x_train.shape[0] // batch_size

# Total number of training epochs
epochs = 150

epoch_count = 1 # index of current epoch (1-based)

# Per-epoch mean losses, one entry appended per epoch
real_loss_all = []
fake_loss_all = []
gan_loss_all = []

while epoch_count <= epochs:
    start_time = time.time() # timer for the epoch to complete

    # Shuffle all four arrays with one permutation so their rows stay aligned
    x_train, y_train, y_target, label_train = utils.shuffle(x_train, y_train, y_target, label_train)

    print('EPOCH: ' + str(epoch_count))

    # Start offset of the current batch within the shuffled arrays
    batch_index = 0  # index of the current batch

    # Store per batch loss within the epoch
    real_loss = []
    fake_loss = []
    gan_loss = []

    for i in range(num_batches):
        # Select a subset of the data corresponding to the indices of the batch
        seq_a_batch = x_train[batch_index:batch_index+batch_size]
        seq_b_batch = y_train[batch_index:batch_index+batch_size]
        masked_seq_a_batch = y_target[batch_index:batch_index+batch_size]
        labels_batch = label_train[batch_index:batch_index+batch_size]

        batch_index += batch_size

        # Discriminator targets: 1 for y_train ("real"), 0 for generator output ("fake")
        target_batch = np.ones([len(seq_a_batch),1])
        fake_batch = np.zeros([len(seq_a_batch),1]) 

        # Train discriminator on real data
        loss = discriminator.train_on_batch(seq_b_batch, target_batch)
        real_loss.append(np.expand_dims(np.array(loss), axis=0))

        # Use the generator to create predictions for the discriminator
        seq_b_batch_fake = generator.predict(seq_a_batch, verbose=0)

        # Train discriminator on fake data
        loss = discriminator.train_on_batch(seq_b_batch_fake, fake_batch)
        fake_loss.append(np.expand_dims(np.array(loss), axis=0))
        
        # Train the generator via the combined GAN: the discriminator score is
        # pushed towards 1, and generator output * labels_batch is pushed
        # towards masked_seq_a_batch (rows of y_target = x_train * label_train).
        loss = GAN.train_on_batch([seq_a_batch, labels_batch], [target_batch, masked_seq_a_batch])
        gan_loss.append(np.expand_dims(np.array(loss), axis=0))

    # Average the per-batch losses over the epoch.
    # NOTE(review): assumes the discriminator returns a single scalar loss and
    # the GAN returns three values (total and the two output losses), matching
    # the three metrics_names printed below — confirm for the installed Keras.
    gan_loss = np.mean(np.concatenate(gan_loss, axis=0), axis=0)
    real_loss = np.mean(np.concatenate(real_loss, axis=0), axis=0)
    fake_loss = np.mean(np.concatenate(fake_loss, axis=0), axis=0)

    print("--- %s seconds ---" % np.round((time.time() - start_time),2))
    print("Discriminator Real Data:", discriminator.metrics_names[0], "=", real_loss)
    print("Discriminator Fake Data:", discriminator.metrics_names[0], "=", fake_loss)
    print("GAN:", GAN.metrics_names[0], "=", gan_loss[0], "-", GAN.metrics_names[1], "=", gan_loss[1], "-", GAN.metrics_names[2], "=", gan_loss[2])
    
    gan_loss_all.append(np.expand_dims(gan_loss, axis=0))
    fake_loss_all.append(np.expand_dims(fake_loss, axis=0))
    real_loss_all.append(np.expand_dims(real_loss, axis=0))

    # Every 10 epochs: checkpoint both models, save the loss histories and
    # plot tsdGAN output on the validation recording
    if (epoch_count % 10 == 0):

        # Save full models as .h5 files.
        # NOTE(review): the output directories are not created here; saving
        # fails if ../../Models/Experimental or
        # ../../Training Performance/Experimental does not exist.
        MODEL_PATH = '../../Models/Experimental/discriminator_epoch' + str(epoch_count) + '.h5'
        discriminator.save(MODEL_PATH)

        MODEL_PATH = '../../Models/Experimental/generator_epoch' + str(epoch_count) + '.h5'
        generator.save(MODEL_PATH)

        # Save the per-epoch loss histories accumulated so far
        np.save('../../Training Performance/Experimental/real_loss.npy', np.concatenate(real_loss_all, axis=0))
        np.save('../../Training Performance/Experimental/fake_loss.npy', np.concatenate(fake_loss_all, axis=0))
        np.save('../../Training Performance/Experimental/gan_loss.npy', np.concatenate(gan_loss_all, axis=0))

        print("\nEpoch " + str(epoch_count) + " Prediction")

        # Filter every validation channel independently with tsdGAN_filter
        raw_df = val_df.copy()
        filt_df = raw_df.apply(lambda x: tsdGAN_filter(x))

        # Descriptive names for each EMG sensor, in column order.
        # NOTE(review): the 5x2 grid below indexes 10 columns in exactly this
        # order — confirm against the validation CSV header.
        better_muscle_names = ['Right Erector Spinae',
            'Left Erector Spinae',
            'Right Internal Oblique',
            'Left Internal Oblique',
            'Right Latissimus Dorsi',
            'Left Latissimus Dorsi',
            'Right Rectus Abdominis',
            'Left Rectus Abdominis',
            'Right External Oblique',
            'Left External Oblique']

        # Time axis in seconds, using the 1920 Hz sampling rate that was used
        # to low-pass val_df
        t = np.arange(len(filt_df))/1920

        # Creating a 5x2 grid of subplots
        fig, axs = plt.subplots(nrows=5, ncols=2, figsize=(12, 12))
        fig.subplots_adjust(hspace=0.5, wspace=0.3)  # Adjust space between plots

        c = 0  # Counter for iterating through DataFrame columns
        for i in range(5):
            for j in range(2):
                ax = axs[i, j]
                # Plotting raw and filtered data
                ax.plot(t, raw_df.iloc[:, c], label="Raw", color="black", linewidth=1.5)
                ax.plot(t, filt_df.iloc[:, c], label="tsdGAN", color="dodgerblue", linewidth=1.25)

                # x-label only on the bottom row, y-label only on the left column
                if i == 4:
                    ax.set_xlabel("Time (s)")
                if j == 0:
                    ax.set_ylabel("Amplitude (mV)")

                ax.set_title(better_muscle_names[c])

                if (i, j) == (0, 0):  # Showing legend only in the first subplot
                    ax.legend()

                c += 1

        # Recomputes subplot spacing, superseding the subplots_adjust call above
        plt.tight_layout()
        plt.show()

    epoch_count += 1