### tsdGAN: A generative adversarial network approach for removing electrocardiographic interference from electromyographic signals 
Lucas Haberkamp<sup>1,2</sup> Charles A. Weisenbach<sup>1</sup> Peter Le<sup>3</sup>  
<sup>1</sup>Naval Medical Research Unit Dayton, Wright-Patterson Air Force Base, OH, USA   
<sup>2</sup>Leidos, Reston, VA, USA   
<sup>3</sup>Air Force Research Laboratory, 711th Human Performance Wing, Wright-Patterson Air Force Base, OH, USA

#### This notebook is used to eliminate ECG interference using template subtraction, as outlined in "Costa et al. (2018): A template subtraction method for reducing electrocardiographic artifacts in EMG signals of low intensity"

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
import plotly.graph_objects as go
import plotly.subplots as sp

Define helper function to load in EMG dataset

In [None]:
def extractdf(data_path):
    """Load every ``.csv`` file found directly inside ``data_path``.

    Directory entries are visited in sorted order and only names whose
    extension is exactly ``.csv`` are read. Each file is parsed with
    ``pd.read_csv(..., header=13)``, i.e. the column names are taken from
    row 13 (0-based) of the file.

    Parameters
    ----------
    data_path : str
        Directory containing the CSV files.

    Returns
    -------
    df_list : list of pandas.DataFrame
        One DataFrame per CSV file, in sorted file-name order.
    file_list : list of str
        File names without their extension, in the same order as ``df_list``.
    """
    csv_names = [
        entry
        for entry in sorted(os.listdir(data_path))
        if os.path.splitext(entry)[1] == '.csv'
    ]

    # The file stem serves as the identifier for each loaded file
    file_list = [os.path.splitext(entry)[0] for entry in csv_names]
    df_list = [
        pd.read_csv(os.path.join(data_path, entry), header=13)
        for entry in csv_names
    ]

    return df_list, file_list

Load experimental validation dataset

In [None]:
# Specify paths
# Relative location of the directory holding the raw validation-trial CSV files
path = '../../Data/Raw TS EMG Data/Validation'

# Extract the data as dataframes stored into lists
# data_list: one DataFrame per CSV file (sorted by file name)
# file_list: the matching file names without their extension
data_list, file_list = extractdf(path)

# Show which files were picked up so the trial order can be checked
print("Files in the dataset:", file_list)

Movement List:   
TrunkStability_DS_S20_EMG_Raw_12 - Deep Sagittal Flexion 20lb Lift  
TrunkStability_DS_S20_EMG_Raw_17 - Sit-to-Stand  
TrunkStability_DS_S20_EMG_Raw_47 - Prolonged Seated Task  
TrunkStability_DS_S20_EMG_Raw_56 - Standing Neutral

In [None]:
# Display name of each trial.
# NOTE(review): these are later indexed in step with data_list, so they
# presumably must follow the sorted file order -- verify against file_list.
mvmt_list = [
    "Sagittal Flexion",
    "Sit-to-Stand",
    "Seated Task",
    "Standing Neutral",
]

# Define hyperparameters
Fs = 1920  # Sampling frequency (Hz)
muscle_list = data_list[0].columns[3:]  # EMG channels: every column from index 3 onward

# Define a new set of more descriptive names for each EMG sensor.
# Generated as Right/Left pairs per muscle, giving the order:
# Right Erector Spinae, Left Erector Spinae, Right Internal Oblique, ...
better_muscle_names = [
    f"{side} {muscle}"
    for muscle in (
        "Erector Spinae",
        "Internal Oblique",
        "Latissimus Dorsi",
        "Rectus Abdominis",
        "External Oblique",
    )
    for side in ("Right", "Left")
]

Define a Butterworth filter function

In [None]:
# Function to butterworth filter the data
def butterfilter(x, Fc, Fs, type):
    """Apply a 2nd-order Butterworth filter forward and backward over ``x``.

    Parameters
    ----------
    x : array-like
        Signal to filter.
    Fc : float or sequence of float
        Cutoff frequency, or a pair of band edges, in the same unit as ``Fs``.
    Fs : float
        Sampling frequency; cutoffs are normalised by ``Fs / 2``.
    type : str
        Filter type handed to ``scipy.signal.butter`` as its third
        positional argument (e.g. ``'low'`` or ``'bandpass'``).

    Returns
    -------
    numpy.ndarray
        Output of ``scipy.signal.filtfilt`` for ``x``.
    """
    nyquist = np.asarray(Fs / 2)
    normalized_cutoff = np.asarray(Fc) / nyquist

    # butter returns the (numerator, denominator) coefficient pair
    coefficients = signal.butter(2, normalized_cutoff, type)
    return signal.filtfilt(*coefficients, x)

Define Template Subtraction Function

In [None]:
def template_subtraction(raw_data, Fs, plotting=False):
    """Remove ECG interference from one EMG channel by template subtraction.

    R peaks are detected on a 4-50 Hz band-passed copy of the zero-padded
    signal, the QRS windows around them are averaged into a template, and an
    amplitude-scaled copy of that template is subtracted from the signal at
    every detected peak.

    Parameters
    ----------
    raw_data : array-like
        One-dimensional signal exposing ``squeeze()`` (e.g. a NumPy array or
        a pandas Series).
    Fs : int or float
        Sampling frequency in Hz.
    plotting : bool, optional
        When True, show a matplotlib diagnostic figure for each stage.

    Returns
    -------
    numpy.ndarray
        The signal with the ECG template subtracted, same length as
        ``raw_data``. If no usable peak is detected the signal is returned
        unmodified.
    """
    # Half-width, in samples, of the 0.16 s QRS window centred on each peak
    qrs_window = int((.16 * Fs) // 2)

    # NOTE(review): the padding is a fixed 1920 samples, which is one second
    # only when Fs == 1920 -- confirm whether it should scale with Fs.
    pad_length = 1920

    # Pad the data with zeros to allow template subtraction at ends of the signal
    data = np.hstack([np.zeros(shape=(pad_length,)), raw_data.squeeze(), np.zeros(shape=(pad_length,))])

    # Define the time array
    t = np.arange(len(data))/Fs

    # Bandpass filter between 4 & 50 Hz
    bandpass_filt = butterfilter(data, Fc=[4, 50], Fs=Fs, type='bandpass')

    if plotting:
        plt.plot(t, data, label='Raw')
        plt.plot(t, bandpass_filt, label='Bandpass Filtered')
        plt.xlabel("Time (s)")
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.legend()
        plt.show()

    # Rectify and compute moving averages
    rect = pd.Series(np.abs(bandpass_filt))
    ma1 = rect.rolling(window=int(0.1*Fs), center=True, min_periods=1).mean()
    ma2 = rect.rolling(window=int(1.0*Fs), center=True, min_periods=1).mean()

    if plotting:
        plt.plot(t, rect, label='Rectified')
        plt.plot(t, ma1, label='0.1s Moving Avg')
        plt.plot(t, ma2, label='1.0s Moving Avg')
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.xlabel("Time (s)")
        plt.legend()
        plt.show()

    # Positive where the 0.1s moving average exceeds the 1.0s moving average
    potential_peaks = (ma1 - ma2).to_numpy()
    bandpass_filt_baseline = bandpass_filt - ma2.to_numpy()

    # Distance = RR interval for a max heart rate of 220 BPM
    distance = 0.27 * Fs

    # Use find_peaks on the bandpass filtered signal [Fc = 4-50Hz]
    peaks, _ = signal.find_peaks(bandpass_filt_baseline, height=0, distance=distance)

    # Keep peaks that occur where 0.1s moving avg >= 1.0s moving avg ...
    above_slow_average = potential_peaks[peaks] >= 0
    # ... and whose full QRS window lies inside the padded array. A window
    # running past either end would be truncated, which breaks both the
    # template average and the fixed-length template assignment below.
    window_fits = (peaks >= qrs_window) & (peaks + qrs_window < len(data))
    filtered_peaks = peaks[above_slow_average & window_fits]

    if plotting:
        plt.plot(t, bandpass_filt_baseline, label='Bandpass Filtered (Detrended)')
        plt.scatter(t[filtered_peaks], bandpass_filt_baseline[filtered_peaks], color='r', marker='*', label='Peak')
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.xlabel("Time (s)")
        plt.legend()
        plt.show()

    # Without any peak there is no template to build; hand the signal back
    # untouched instead of averaging an empty list (which yields NaN).
    if filtered_peaks.size == 0:
        return data[pad_length:-pad_length]

    # Obtain all detected QRS complexes
    qrs_segments = []
    for peak in filtered_peaks:
        curr_qrs = bandpass_filt[peak-qrs_window:peak+qrs_window+1]
        qrs_segments.append(curr_qrs)
        if plotting:
            plt.plot(curr_qrs, label=str(peak))
    # Determine average QRS complex
    qrs_template = np.mean(qrs_segments, axis=0)

    if plotting:
        plt.plot(qrs_template, linewidth=3, color='k', label='Template')
        plt.legend(loc='right')
        plt.title("Detected QRS Complexes")
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.xlabel("Samples")
        plt.show()

    # Create ECG template: at each peak, place a copy of the QRS template
    # rescaled by the ratio of the band-passed value at that peak to the
    # template maximum.
    template_max = np.max(qrs_template)
    ecg_template = np.zeros(len(data))
    for peak in filtered_peaks:
        adj_qrs_template = qrs_template * (bandpass_filt[peak] / template_max)
        ecg_template[peak-qrs_window:peak+qrs_window+1] = adj_qrs_template

    if plotting:
        plt.plot(t, bandpass_filt, label='Filtered')
        plt.plot(t, ecg_template, color='r', label='ECG Template')
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.xlabel("Time (s)")
        plt.legend()
        plt.show()

    # Subtract the ECG template from the raw data and remove the padding
    clean_emg = data - ecg_template
    clean_emg = clean_emg[pad_length:-pad_length]

    if plotting:
        unpadded = data[pad_length:-pad_length]
        new_t = np.arange(len(unpadded))/Fs

        plt.plot(new_t, unpadded, label='Raw')
        plt.plot(new_t, clean_emg, label='Clean EMG')
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.xlabel("Time (s)")
        plt.legend()
        plt.show()

    return clean_emg

Process each trial with the template subtraction approach and create dynamic plots using Plotly

In [None]:
# For every trial: clean each EMG channel with template subtraction, then
# build a 5x2 Plotly grid (one subplot per channel) overlaying the raw and
# cleaned signals, save it as a standalone HTML page and display it.
for mvmt, df in enumerate(data_list):
    # Build a new frame rather than modifying the column selection in place
    df = df[muscle_list]
    df = df - df.mean() # Ensure data has zero mean
    df = df.apply(lambda x: butterfilter(x, Fc=500, Fs=Fs, type='low')) # Remove unwanted high-frequency components

    df_pred = df.apply(lambda x: template_subtraction(x, Fs=Fs, plotting=False)) # Apply the Template Subtraction filter

    # Define time-axis
    t = np.arange(len(df))/Fs

    fig = sp.make_subplots(
        rows=5,
        cols=2,
        subplot_titles=[f"{mvmt_list[mvmt]}: {better_muscle_names[i]}" for i in range(len(df.columns))],
        vertical_spacing=0.05,
        horizontal_spacing=0.05,  # Add horizontal_spacing to reduce the space between subplots horizontally
        specs=[[{}, {}]] * 5,
    )

    for i in range(5):
        for j in range(2):
            c = i * 2 + j  # Channel index: subplots are filled row by row
            showlegend_flag = c == 0  # One legend entry per trace group is enough
            fig.add_trace(go.Scatter(x=t, y=df.iloc[:, c], name="Raw", legendgroup="Raw", line=dict(color="black", width=1.5), showlegend=showlegend_flag), row=i+1, col=j+1)
            fig.add_trace(go.Scatter(x=t, y=df_pred.iloc[:, c], name="Template Subtraction", legendgroup="Template Subtraction", showlegend=showlegend_flag, line=dict(color="#E69F00", width=1.25)), row=i+1, col=j+1)

            # Only the bottom row and the left column carry axis titles.
            # NOTE(review): the diagnostic plots in template_subtraction label
            # amplitude in microvolts while this figure says mV -- confirm the unit.
            xaxis_title = "Time (s)" if i == 4 else None
            yaxis_title = "Amplitude (mV)" if j == 0 else None

            fig.update_xaxes(title_text=xaxis_title, row=i+1, col=j+1, showgrid=False, title_standoff=10, tickfont=dict(size=10), title_font=dict(size=12))  # Update title_font for x-axis
            fig.update_yaxes(title_text=yaxis_title, row=i+1, col=j+1, showgrid=False, title_standoff=10, tickfont=dict(size=10), title_font=dict(size=12))  # Update title_font for y-axis

    # Figure-wide layout only needs to be applied once, after all traces exist
    fig.update_layout(
        height=1200,
        width=1000,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1, font=dict(size=12)),
        plot_bgcolor="white",
        margin=dict(l=10, r=10, t=100, b=50, autoexpand=True),  # set l and r to center the plot
    )

    # Wrap the figure in a minimal page that centres the chart
    html_string = f'''
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <style>
            .plotly-chart-container {{
                display: flex;
                justify-content: center;
                align-items: center;
                height: 100%;
                width: 100%;
            }}
        </style>
    </head>
    <body>
        <div class="plotly-chart-container">
            {fig.to_html(full_html=False, include_plotlyjs='cdn')}
        </div>
    </body>
    </html>
    '''

    out_path = f"../../Plots/Experimental Validation Examples/Template Subtraction/Template Subtraction_{mvmt_list[mvmt]}.html"
    # Create the output folder on first use so the write below cannot fail on a missing directory
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # The page declares charset UTF-8, so write it as UTF-8 regardless of the platform default
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(html_string)

    fig.show()