### tsdGAN: A generative adversarial network approach for removing electrocardiographic interference from electromyographic signals 
Lucas Haberkamp<sup>1,2</sup>, Charles A. Weisenbach<sup>1</sup>, Peter Le<sup>3</sup>  
<sup>1</sup>Naval Medical Research Unit Dayton, Wright-Patterson Air Force Base, OH, USA   
<sup>2</sup>Leidos, Reston, VA, USA   
<sup>3</sup>Air Force Research Laboratory, 711th Human Performance Wing, Wright-Patterson Air Force Base, OH, USA

#### This notebook is used to train the tsdGAN deep learning model on synthetic data

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K
from sklearn import utils
from tensorflow.keras.activations import gelu
import pandas as pd
from scipy import signal, spatial
import time
import pickle

Load the training and validation datasets

In [None]:
# Synthetic training dataset: contaminated inputs, unpaired clean sequences,
# and the per-sample QRS label arrays (loaded in that order).
x_train, y_train, label_train = (
    np.load('../../Data/Preprocessed Data/Synthetic/Training/x_syn.npy'),
    np.load('../../Data/Preprocessed Data/Synthetic/Training/y_syn.npy'),
    np.load('../../Data/Preprocessed Data/Synthetic/Training/label_syn.npy'),
)

# Function to convert validation dictionary to array
def dict2arr(dictionary):
  x = []
  for i, value in enumerate(dictionary.values()):
    x.append(value)
  return np.vstack(x)

# Validation data is stored as pickled dicts whose values are arrays; each one
# is flattened into a single array with dict2arr.
# NOTE(review): pickle.load can execute arbitrary code from the file - only
# load trusted, locally produced data here.
with open('../../Data/Preprocessed Data/Synthetic/Validation/x_val.pkl', 'rb') as fh:
    x_val = dict2arr(pickle.load(fh))
print(x_val.shape)

with open('../../Data/Preprocessed Data/Synthetic/Validation/y_val.pkl', 'rb') as fh:
    y_val = dict2arr(pickle.load(fh))

Mask the contaminated training signals with the QRS labels (`y_target = x_train * label_train`), so that the reconstruction target does not contain QRS locations

In [None]:
# Reconstruction target: each contaminated training signal multiplied by its
# label array, so samples where the label is 0 (assumed to be the QRS
# locations - confirm against the label generation) are masked out.
y_target = np.array([sample * label_train[idx] for idx, sample in enumerate(x_train)])

In [None]:
# Sanity-check the shapes of the training arrays.
for caption, arr in (("x_shape:", x_train),
                     ("y_shape:", y_train),
                     ("y_target:", y_target),
                     ("label_shape:", label_train)):
    print(caption, arr.shape)

In [None]:
# Confirm the synthetic data is unpaired: overlay the contaminated input and
# the clean sequence stored at the same index for the first 5 samples.
for i in range(5):
    plt.plot(x_train[i])
    plt.plot(y_train[i])
    plt.title("Unpaired Synthetic Data: " + str(i))
    plt.xlabel("Samples")
    # Raw string: "\m" is an invalid escape in a normal literal (Deprecation-/
    # SyntaxWarning); the rendered label text is unchanged.
    plt.ylabel(r"Amplitude ($\mu$V)")
    plt.show()

In [None]:
# Confirm the QRS complexes are masked in the reconstruction target.
for i in range(5):
    plt.plot(y_target[i])
    plt.title("Masked Target Data: " + str(i))
    plt.xlabel("Samples")
    # Raw string: "\m" is an invalid escape in a normal literal (Deprecation-/
    # SyntaxWarning); the rendered label text is unchanged.
    plt.ylabel(r"Amplitude ($\mu$V)")
    plt.show()

In [None]:
# Plot QRS complex locations (label arrays) over the contaminated data.
for i in range(10):
    plt.plot(x_train[i])
    plt.plot(label_train[i])
    plt.title('QRS Complex Label: ' + str(i))
    plt.xlabel("Samples")
    # Raw string: "\m" is an invalid escape in a normal literal (Deprecation-/
    # SyntaxWarning); the rendered label text is unchanged.
    plt.ylabel(r"Amplitude ($\mu$V)")
    plt.show()

Define the TCN separation module

In [None]:
def tcn_block(inputs, filters, dilation_rates):
    """Temporal convolutional (TCN) stack used as the separation module.

    The input is layer-normalised once. Then, for every entry of
    ``dilation_rates``, a residual unit runs two steps:

    - two kernel-3, 'same'-padded separable convolutions with GELU, the second
      preceded by layer normalisation;
    - adding the result back to the running residual and re-normalising.

    Args:
        inputs: Keras tensor to process. The residual ``Add`` requires its
            channel count to equal ``filters``.
        filters: Output channels of every separable convolution.
        dilation_rates: Dilation rates; one residual unit is built per entry.

    Returns:
        Keras tensor with the time length of ``inputs`` and ``filters`` channels.
    """
    def dilated_sep_conv(tensor, rate):
        # Both convolutions of a unit share the same configuration.
        return SeparableConv1D(filters, dilation_rate=rate, kernel_size=3,
                               padding='same', activation=gelu)(tensor)

    residual = LayerNormalization()(inputs)
    for rate in dilation_rates:
        branch = dilated_sep_conv(residual, rate)
        branch = dilated_sep_conv(LayerNormalization()(branch), rate)
        residual = LayerNormalization()(Add()([branch, residual]))
    return residual

Define the generator model

In [None]:
def build_generator():
    """Build the encoder / TCN-mask / decoder generator model.

    Architecture:

    - Encoder: a kernel-7 conv, then two stride-2 convs (time downsampled by 4),
      each GELU-activated and layer-normalised, then a pointwise conv with
      128 channels.
    - Separator: ``tcn_block`` followed by a pointwise sigmoid conv, producing a
      gate in [0, 1]. The gate multiplies the encoder features and the product
      is layer-normalised.
    - Decoder: two stride-2 transposed convs (upsampling by 4), each
      layer-normalised, then a linear kernel-7 conv down to 1 channel.

    NOTE: with 'same' padding the output length is ``4 * ceil(L / 4)``, so it
    matches the input length ``L`` only when ``L`` is a multiple of 4.

    Returns:
        ``tf.keras.Model`` named ``'generator'`` mapping ``(batch, L, 1)`` to
        ``(batch, ~L, 1)``.
    """
    dilation_rates = [2 ** k for k in range(7)]  # 1, 2, 4, ..., 64

    def conv_ln(tensor, layer):
        # Apply a (transposed) convolution, then layer normalisation.
        return LayerNormalization()(layer(tensor))

    signal_in = Input(shape=(None, 1))

    # Encoder
    enc = conv_ln(signal_in, Conv1D(filters=32, kernel_size=7, padding='same', activation=gelu))
    enc = conv_ln(enc, Conv1D(filters=64, strides=2, kernel_size=3, padding='same', activation=gelu))
    enc = conv_ln(enc, Conv1D(filters=128, strides=2, kernel_size=3, padding='same', activation=gelu))
    enc = Conv1D(filters=128, kernel_size=1, padding='same', activation=gelu)(enc)

    # Separator: TCN -> sigmoid gate applied to the encoder features
    gate = tcn_block(enc, filters=128, dilation_rates=dilation_rates)
    gate = Conv1D(filters=128, kernel_size=1, padding='same', activation='sigmoid')(gate)
    masked = LayerNormalization()(Multiply()([enc, gate]))

    # Decoder
    dec = conv_ln(masked, Conv1DTranspose(filters=64, strides=2, kernel_size=3, padding='same', activation=gelu))
    dec = conv_ln(dec, Conv1DTranspose(filters=32, strides=2, kernel_size=3, padding='same', activation=gelu))
    signal_out = Conv1D(1, kernel_size=7, padding='same')(dec)

    return Model(signal_in, signal_out, name='generator')

# Print the generator architecture; the throwaway model is not kept.
build_generator().summary()

In [None]:
def build_discriminator():
    """Build the sequence discriminator.

    Pipeline:

    - Gaussian noise (stddev 0.1) is added to the input.
    - A kernel-7 conv (32 ch) and a stride-2 kernel-3 conv (64 ch) follow,
      each GELU-activated and layer-normalised.
    - A pointwise conv, a ``tcn_block`` with dilations 1..128, and another
      pointwise conv come next.
    - Global average pooling over time feeds a linear ``Dense(1)`` score.

    Returns:
        ``tf.keras.Model`` named ``'discriminator'`` mapping ``(batch, L, 1)``
        to one unbounded score per sequence, shape ``(batch, 1)``.
    """
    dilation_rates = [2 ** k for k in range(8)]  # 1, 2, 4, ..., 128

    seq_in = Input(shape=(None, 1))
    h = GaussianNoise(0.1)(seq_in)

    # (filters, kernel_size, strides) for the two conv + layer-norm stages
    for n_filters, kernel, stride in ((32, 7, 1), (64, 3, 2)):
        h = Conv1D(filters=n_filters, strides=stride, kernel_size=kernel,
                   padding='same', activation=gelu)(h)
        h = LayerNormalization()(h)

    h = Conv1D(filters=64, kernel_size=1, padding='same', activation=gelu)(h)
    h = tcn_block(h, filters=64, dilation_rates=dilation_rates)
    h = Conv1D(filters=64, kernel_size=1, padding='same', activation=gelu)(h)
    h = GlobalAveragePooling1D()(h)

    return Model(seq_in, Dense(1)(h), name='discriminator')

# Print the discriminator architecture; the throwaway model is not kept.
build_discriminator().summary()

Function to create and compile the generator, the discriminator, and the combined GAN model

In [None]:
def create_gans():
    """Build the generator/discriminator pair and the combined GAN model.

    The discriminator is compiled on its own (Adam, lr=2e-4, beta_1=0.5, MSE
    loss) while still trainable. It is then set to ``trainable = False``
    before being wired into the combined model, so that compiling ``GAN``
    treats the discriminator weights as frozen.
    NOTE(review): this relies on the ordering of ``compile()`` and the
    ``trainable`` flag - confirm it behaves as intended on the installed
    Keras version.

    The combined ``GAN`` takes ``[input_seq, input_weights]`` and returns
    ``[global_out, mask_output_seq_b]``:

    - ``global_out``: discriminator score of the generator output
      (MSE loss, weight 1).
    - ``mask_output_seq_b``: generator output multiplied element-wise by
      ``input_weights`` (MAE loss, weight 10). The training loop passes the
      QRS label arrays as ``input_weights``.

    ``GAN`` is compiled with Adam (lr=1e-4, beta_1=0.5).

    Returns:
        Tuple ``(GAN, generator, discriminator)``.
    """

    # create building blocks
    discriminator = build_discriminator()
    generator = build_generator()

    # compile discriminators while they're set to trainable
    # (its own train_on_batch calls must update its weights)
    optimizer = Adam(learning_rate=2e-4, beta_1=0.5)
    discriminator.compile(optimizer=optimizer, loss='mse')

    # Freeze before building/compiling the combined model; must stay after the
    # compile() call above.
    discriminator.trainable = False

    # create the first GAN architecture
    input_seq = Input(shape=(None,1))
    input_weights = Input(shape=(None,1))

    output_seq_b = generator(input_seq)
    # Adversarial branch: score the generated sequence
    global_out = discriminator(output_seq_b)
    # Reconstruction branch: generated sequence masked by the per-sample weights
    mask_output_seq_b = Multiply(name='reconstruction')([output_seq_b, input_weights])

    GAN = Model([input_seq, input_weights], [global_out, mask_output_seq_b], name='GAN')

    optimizer = Adam(learning_rate=1e-4, beta_1=0.5)
    # Losses/weights are matched positionally to [global_out, mask_output_seq_b]
    loss = ['mse', 'mae']
    loss_weights = [1, 10]

    GAN.compile(optimizer=optimizer, loss=loss, loss_weights=loss_weights)

    return GAN, generator, discriminator

Create the models

In [None]:
# Reset Keras' global state (e.g. layer-name counters left by the summary
# models built above) *before* constructing the models that will be trained;
# clearing it only after construction does not affect these models.
tf.keras.backend.clear_session()
GAN, generator, discriminator = create_gans()

Define evaluation metric functions

In [None]:
def SNR(y_true, y_pred):
    """Return the signal-to-noise ratio of ``y_pred`` against ``y_true`` in dB.

    Computed as ``10 * log10(var(y_true) / var(y_true - y_pred))``. Both
    variances are taken over all elements, so a batch of signals is pooled
    into one value rather than averaged per signal.
    """
    residual_var = np.var(y_true - y_pred)
    return 10 * np.log10(np.var(y_true) / residual_var)

def freq_transform(data, Fs):
    """Estimate the power spectral density of ``data`` with Welch's method.

    Welch settings:

    - Hann window of 1 second (``Fs`` samples).
    - 75% overlap.
    - Zero-padding of each segment to the next power of two.

    If the signal is shorter than one second, the window is shortened to the
    signal length and the overlap and FFT size are derived from it. Otherwise
    the fixed overlap computed from the 1 s window can reach or exceed the
    segment length that ``welch`` falls back to, and ``welch`` raises.

    Args:
        data: Signal(s); the PSD is computed along the last axis.
        Fs: Sampling frequency in Hz.

    Returns:
        The PSD estimate ``pxx`` (the frequency axis is discarded).
    """
    # Window length of 1 second, clamped to the signal length if shorter
    window_len = Fs * 1
    n_samples = np.shape(data)[-1] if np.ndim(data) else 0
    if 0 < n_samples < window_len:
        window_len = n_samples
    nperseg = int(window_len)
    # 75% overlap
    noverlap = int(np.ceil(window_len * 0.75))
    # Zero-padding to the next power of two
    nfft = int(2 ** np.ceil(np.log2(window_len)))

    _, pxx = signal.welch(data, fs=Fs, window='hann', nperseg=nperseg,
                          noverlap=noverlap, nfft=nfft)
    return pxx

def get_jsd(y_true, y_pred, Fs):
    """Mean spectral Jensen-Shannon score between paired signals.

    For each index ``k`` in ``range(y_pred.shape[0])``:

    - both ``y_true[k]`` and ``y_pred[k]`` are squeezed and turned into Welch
      PSDs with ``freq_transform``;
    - the PSDs are compared with ``scipy.spatial.distance.jensenshannon``
      (base 2).

    NOTE: SciPy's ``jensenshannon`` returns the Jensen-Shannon *distance*,
    i.e. the square root of the divergence.

    Returns:
        The mean of the per-pair values.
    """
    distances = [
        spatial.distance.jensenshannon(
            freq_transform(y_true[k].squeeze(), Fs=Fs),
            freq_transform(y_pred[k].squeeze(), Fs=Fs),
            base=2,
        )
        for k in range(y_pred.shape[0])
    ]
    return np.mean(np.array(distances))

In [None]:
# Reset global state generated by Keras (K is tensorflow.keras.backend).
# NOTE(review): the models from create_gans() already exist at this point;
# this call does not rebuild or reset them.
K.clear_session()

Training Loop

In [None]:
# Core training loop: for every mini-batch, train the discriminator on real and
# generated sequences, then train the generator through the combined GAN model.
batch_size = 64

# Calculate number of (full) batches; a trailing partial batch is skipped
num_batches = x_train.shape[0] // batch_size
if num_batches == 0:
    # Otherwise no batch is ever trained and np.concatenate on the empty
    # per-batch loss lists below fails with an unhelpful error.
    raise ValueError(
        f"Need at least batch_size={batch_size} training samples, got {x_train.shape[0]}."
    )

# total epochs
epochs = 150

epoch_count = 1  # index of current epoch (1-based)

# Per-epoch mean losses, accumulated over the whole run
real_loss_all = []
fake_loss_all = []
gan_loss_all = []

while epoch_count <= epochs:
    start_time = time.time()  # timer for the epoch to complete

    # Shuffle all training arrays with one shared permutation
    x_train, y_train, y_target, label_train = utils.shuffle(x_train, y_train, y_target, label_train)

    print('EPOCH: ' + str(epoch_count))

    # Store per batch loss within the epoch
    real_loss = []
    fake_loss = []
    gan_loss = []

    for i in range(num_batches):
        # Select the i-th contiguous batch of every array
        batch = slice(i * batch_size, (i + 1) * batch_size)
        seq_a_batch = x_train[batch]          # generator input
        seq_b_batch = y_train[batch]          # shown to the discriminator as "real"
        masked_seq_a_batch = y_target[batch]  # target for the masked reconstruction output
        labels_batch = label_train[batch]     # mask applied to the reconstruction output

        target_batch = np.ones([len(seq_a_batch), 1])
        fake_batch = np.zeros([len(seq_a_batch), 1])

        # Train discriminator on real data
        loss = discriminator.train_on_batch(seq_b_batch, target_batch)
        real_loss.append(np.expand_dims(np.array(loss), axis=0))

        # Generate fake samples for the discriminator. Calling the model
        # directly avoids predict()'s per-call overhead for a single batch;
        # none of the generator's layers behave differently in training
        # mode, so the output matches predict().
        seq_b_batch_fake = generator(seq_a_batch, training=False).numpy()

        # Train discriminator on fake data
        loss = discriminator.train_on_batch(seq_b_batch_fake, fake_batch)
        fake_loss.append(np.expand_dims(np.array(loss), axis=0))

        # Train the generator through GAN (discriminator frozen there): push
        # the adversarial score towards "real" plus masked reconstruction
        loss = GAN.train_on_batch([seq_a_batch, labels_batch], [target_batch, masked_seq_a_batch])
        gan_loss.append(np.expand_dims(np.array(loss), axis=0))

    gan_loss = np.mean(np.concatenate(gan_loss, axis=0), axis=0)
    real_loss = np.mean(np.concatenate(real_loss, axis=0), axis=0)
    fake_loss = np.mean(np.concatenate(fake_loss, axis=0), axis=0)

    print("--- %s seconds ---" % np.round((time.time() - start_time),2))
    print("Discriminator Real Data:", discriminator.metrics_names[0], "=", real_loss)
    print("Discriminator Fake Data:", discriminator.metrics_names[0], "=", fake_loss)
    print("GAN:", GAN.metrics_names[0], "=", gan_loss[0], "-", GAN.metrics_names[1], "=", gan_loss[1], "-", GAN.metrics_names[2], "=", gan_loss[2])

    gan_loss_all.append(np.expand_dims(gan_loss, axis=0))
    fake_loss_all.append(np.expand_dims(fake_loss, axis=0))
    real_loss_all.append(np.expand_dims(real_loss, axis=0))

    # Validate performance every 10 epochs
    if (epoch_count % 10 == 0):

        # Save models
        MODEL_PATH = '../../Models/Synthetic/discriminator_epoch' + str(epoch_count) + '.h5'
        discriminator.save(MODEL_PATH)

        MODEL_PATH = '../../Models/Synthetic/generator_epoch' + str(epoch_count) + '.h5'
        generator.save(MODEL_PATH)

        # Save current training performance
        np.save('../../Training Performance/Synthetic/real_loss.npy', np.concatenate(real_loss_all, axis=0))
        np.save('../../Training Performance/Synthetic/fake_loss.npy', np.concatenate(fake_loss_all, axis=0))
        np.save('../../Training Performance/Synthetic/gan_loss.npy', np.concatenate(gan_loss_all, axis=0))

        print("\nEpoch " + str(epoch_count) + " Prediction")

        # Predict on synthetic validation data
        y_pred = generator.predict(x_val, batch_size=32, verbose=0)

        plt.plot(x_val[0])
        plt.plot(y_pred[0])
        plt.title("Synthetic A>B Prediction")
        plt.show()

        snr = SNR(y_val, y_pred)
        print("Predicted Signal-Noise Ratio:", snr, "\n")

        # NOTE: scipy's jensenshannon (used by get_jsd) returns the JS
        # *distance* (square root of the divergence).
        average_jsd = get_jsd(y_val, y_pred, Fs=1920)
        print("Average Jensen-Shannon Divergence:", average_jsd, "Bits")

    epoch_count += 1
