### tsdGAN: A generative adversarial network approach for removing electrocardiographic interference from electromyographic signals 
Lucas Haberkamp<sup>1,2</sup> Charles A. Weisenbach<sup>1</sup> Peter Le<sup>3</sup>  
<sup>1</sup>Naval Medical Research Unit Dayton, Wright-Patterson Air Force Base, OH, USA   
<sup>2</sup>Leidos, Reston, VA, USA   
<sup>3</sup>Air Force Research Laboratory, 711th Human Performance Wing, Wright-Patterson Air Force Base, OH, USA

#### This notebook is used to eliminate ECG interference using dynamic filtration, as outlined in "Christov et al. (2018): Separation of electrocardiographic from electromyographic signals using dynamic filtration"

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
import plotly.graph_objects as go
import plotly.subplots as sp

Define a helper function to load the EMG dataset

In [None]:
def extractdf(data_path):
    """Read every ``.csv`` file in a directory into a DataFrame.

    Files are visited in sorted name order. Each one is parsed with
    ``header=13``, i.e. the 14th line of the file supplies the column names
    and everything above it is skipped.

    Parameters
    ----------
    data_path : str
        Directory to scan. Entries whose extension is not exactly ``.csv``
        are ignored.

    Returns
    -------
    df_list : list of pandas.DataFrame
        One DataFrame per CSV file.
    file_list : list of str
        File names without extension, in the same order as ``df_list``.
    """
    csv_names = [
        name for name in sorted(os.listdir(data_path))
        if os.path.splitext(name)[1] == '.csv'
    ]

    # The file stem doubles as the identifier of the trial
    file_list = [os.path.splitext(name)[0] for name in csv_names]
    df_list = [
        pd.read_csv(os.path.join(data_path, name), header=13)
        for name in csv_names
    ]

    return df_list, file_list

Load experimental validation dataset

In [None]:
# Relative path of the folder holding the raw validation trials
path = '../../Data/Raw TS EMG Data/Validation'

# Read each trial into its own DataFrame and keep the matching file identifiers
data_list, file_list = extractdf(path)

print(f"Files in the dataset: {file_list}")

Movement List:   
TrunkStability_DS_S20_EMG_Raw_12 - Deep Sagittal Flexion 20lb Lift  
TrunkStability_DS_S20_EMG_Raw_17 - Sit-to-Stand  
TrunkStability_DS_S20_EMG_Raw_47 - Prolonged Seated Task  
TrunkStability_DS_S20_EMG_Raw_56 - Standing Neutral

In [None]:
# Display name of each trial; indexed by position in data_list further below
mvmt_list = [
    "Sagittal Flexion",
    "Sit-to-Stand",
    "Seated Task",
    "Standing Neutral",
]

# Sampling frequency (samples per second)
Fs = 1920

# EMG channels: every column of the first trial from the fourth one onward
muscle_list = data_list[0].columns[3:]

# More descriptive name for each EMG sensor, built as Right/Left pairs per muscle
better_muscle_names = [
    f"{side} {muscle}"
    for muscle in (
        "Erector Spinae",
        "Internal Oblique",
        "Latissimus Dorsi",
        "Rectus Abdominis",
        "External Oblique",
    )
    for side in ("Right", "Left")
]

Define a Butterworth filter function

In [None]:
def butterfilter(x, Fc, Fs, type):
    """Apply a 2nd-order Butterworth filter forwards and backwards.

    Parameters
    ----------
    x : array-like
        Signal to filter.
    Fc : float or sequence of float
        Cut-off frequency (or pair of band edges), in the same units as ``Fs``.
    Fs : float
        Sampling frequency.
    type : str
        Filter type forwarded to ``scipy.signal.butter`` (e.g. ``'low'``).

    Returns
    -------
    numpy.ndarray
        Filtered signal returned by ``scipy.signal.filtfilt``.
    """
    # scipy expects the cut-off as a fraction of the Nyquist frequency
    nyquist = Fs / 2
    normalized_cutoff = np.asarray(Fc) / np.asarray(nyquist)
    numerator, denominator = signal.butter(2, normalized_cutoff, type)
    return signal.filtfilt(numerator, denominator, x)

Define Dynamic Filtration Function

In [None]:
def dynamic_ecg_filter(raw_data, fs, plot=False):
    """Estimate the ECG component of one EMG channel and subtract it.

    The channel is zero-padded, an envelope ("wings") is derived from a
    smoothed copy of the signal, the envelope is rescaled to 0..100 and
    combined with a local sum of absolute amplitudes to give a per-sample
    window length. A 2nd-order Savitzky-Golay smoother with that window
    length is then evaluated at every sample; its output is treated as the
    ECG estimate and subtracted from the padded signal.

    Parameters
    ----------
    raw_data : array-like
        One channel of samples. It is converted to float and squeezed to 1-D.
    fs : int or float
        Sampling frequency in samples per second. ``int(fs)`` zeros are
        padded on each side before processing and removed again at the end.
    plot : bool, optional
        When True, intermediate signals are drawn with the module-level
        ``plt`` (matplotlib.pyplot).

    Returns
    -------
    numpy.ndarray
        The input channel minus the estimated ECG component; same length as
        the input.
    """
    pad_length = int(fs)
    channel = np.atleast_1d(np.squeeze(np.asarray(raw_data, dtype=float)))

    # Pad the data with zeros so the filters can run up to the ends of the signal
    edge_pad = np.zeros(pad_length)
    emg_data = np.hstack([edge_pad, channel, edge_pad])
    len_data = len(emg_data)

    if plot:
        # Time axis of the padded signal; uses the fs argument, not a global
        t = np.arange(len_data) / fs

    # Section 1: Moving Average Filter (0.02s)
    win_size1 = round(fs * 0.02)
    ma_filt = pd.Series(emg_data).rolling(window=win_size1, min_periods=1, center=True).mean()

    # Section 2: Savitzky-Golay Filter (0.06s), 2nd order
    win_size2 = round(fs * 0.06)
    # Force an odd window length so the window has a centre sample
    if win_size2 % 2 == 0:
        win_size2 += 1
    sg_filt = signal.savgol_filter(ma_filt.to_numpy(), window_length=win_size2, polyorder=2)

    # Section 3: Wings Calculation (0.01s offset on each side)
    # wings[i] = -|(s[i] - s[i - w]) * (s[i] - s[i + w])| for the samples that
    # have both neighbours strictly inside the signal; zero elsewhere.
    wings_win = round(fs * 0.01)
    wings = np.zeros_like(sg_filt)
    inner = np.arange(wings_win + 1, len_data - wings_win)
    wings[inner] = -np.abs(
        (sg_filt[inner] - sg_filt[inner - wings_win])
        * (sg_filt[inner] - sg_filt[inner + wings_win])
    )

    # Section 4: Moving Average Filters (two passes of a 0.05s moving average)
    ma_window = int(np.ceil(fs * 0.05))
    ma_filt1 = pd.Series(wings).rolling(window=ma_window, min_periods=1, center=True).mean()
    ma_filt2 = ma_filt1.rolling(window=ma_window, min_periods=1, center=True).mean()

    if plot:
        plt.plot(t, emg_data, label='Raw EMG data')
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.xlabel("Time (s)")
        plt.legend()
        plt.show()

        plt.plot(t, wings, label='Wings')
        plt.plot(t, ma_filt2, label='MA_Filt2')
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.xlabel("Time (s)")
        plt.legend()
        plt.show()

    # Section 6: Normalization of the smoothed wings to the range nmin..nmax
    wmax, wmin = ma_filt2.max(), ma_filt2.min()
    nmin, nmax = 0, 100
    if wmax > wmin:
        n = (nmin + (nmax - nmin) * (ma_filt2 - wmin) / (wmax - wmin)).to_numpy()
    else:
        # Flat envelope (e.g. an all-zero channel): avoid dividing by zero
        n = np.full(len_data, float(nmin))

    # Section 7: Calculation of L, the sum of |emg| over 25 samples on each
    # side of the current sample (the window is truncated at the signal ends)
    half_l = 25
    abs_emg = np.abs(emg_data)
    L = np.array([
        abs_emg[max(ii - half_l, 0):ii + half_l + 1].sum()
        for ii in range(len_data)
    ])

    # Section 8: Combination of Results (use L wherever it exceeds n)
    n_lim = np.where(n < L, L, np.ceil(n))

    if plot:
        plt.plot(t, L, label='L')
        plt.plot(t, n, label='n')
        plt.plot(t, n_lim, label='n_min')
        plt.xlabel("Time (s)")
        plt.legend()
        plt.show()

    # Section 9: Ensure every window length is odd and at least 3
    n_lim_est = np.ceil(n_lim).astype(int)
    n_lim_est_corrected = np.maximum(n_lim_est + (n_lim_est % 2 == 0), 3)

    # Section 10: Apply dynamic Savitzky-Golay filter
    # Only the centre value of each window is needed, so the smoother is
    # evaluated as a dot product with the Savitzky-Golay coefficients, which
    # are computed once per distinct window length.
    half_widths = (n_lim_est_corrected - 1) // 2
    max_half = int(half_widths.max())
    pad_ends = np.zeros(max_half)
    zero_padded_data = np.concatenate((pad_ends, emg_data, pad_ends))
    ecg_filtered = np.empty_like(emg_data)
    coeff_cache = {}

    for ii, (window_length, half) in enumerate(
        zip(n_lim_est_corrected.tolist(), half_widths.tolist())
    ):
        coeffs = coeff_cache.get(window_length)
        if coeffs is None:
            coeffs = signal.savgol_coeffs(window_length, 2, use='dot')
            coeff_cache[window_length] = coeffs
        current_idx = ii + max_half
        # The window is centred on current_idx, so the dot product is the
        # smoothed value of sample ii itself (0-based centre, no +1 offset).
        ecg_filtered[ii] = np.dot(
            coeffs, zero_padded_data[current_idx - half:current_idx + half + 1]
        )

    if plot:
        plt.plot(t, ecg_filtered, color='r', label='ECG Signal')
        plt.xlabel("Time (s)")
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.legend()
        plt.show()

    # Section 11: Subtract the ECG estimate and drop the zero padding
    filt_emg = (emg_data - ecg_filtered)[pad_length:len_data - pad_length]

    if plot:
        # Plot against time so the axis matches its "Time (s)" label
        t_raw = np.arange(len(filt_emg)) / fs
        plt.plot(t_raw, channel, label='Raw Data')
        plt.plot(t_raw, filt_emg, label='Dynamic ECG Filter')
        plt.xlabel("Time (s)")
        plt.ylabel(r"Amplitude ($\mu$V)")
        plt.legend()
        plt.show()
    return filt_emg

Process each trial with the dynamic filtration approach and create dynamic plots using Plotly

In [None]:
# For every trial: centre and low-pass filter each EMG channel, run the
# dynamic filtration, and save/show a 5x2 grid of interactive Plotly subplots
# comparing the signal before and after dynamic filtration.
n_rows, n_cols = 5, 2
output_dir = "../../Plots/Experimental Validation Examples/Dynamic Filtration"
os.makedirs(output_dir, exist_ok=True)  # make sure the target folder exists before writing

# Axis styling shared by every subplot
axis_style = dict(showgrid=False, title_standoff=10, tickfont=dict(size=10), title_font=dict(size=12))

for mvmt, df in enumerate(data_list):
    df = df[muscle_list]
    df = df - df.mean()  # Ensure data has zero mean (no in-place arithmetic on a column selection)
    df = df.apply(lambda x: butterfilter(x, Fc=500, Fs=Fs, type='low'))  # Remove unwanted high-frequency components

    df_pred = df.apply(lambda x: dynamic_ecg_filter(x, fs=Fs, plot=False))  # Apply the dynamic filtration

    # Define time-axis
    t = np.arange(len(df)) / Fs

    fig = sp.make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[f"{mvmt_list[mvmt]}: {better_muscle_names[i]}" for i in range(len(df.columns))],
        vertical_spacing=0.05,
        horizontal_spacing=0.05,  # reduce the space between subplots horizontally
        specs=[[{}, {}]] * 5,
    )

    # Channel c is drawn in row i, column j (row-major, matching the titles)
    for c in range(n_rows * n_cols):
        i, j = divmod(c, n_cols)
        showlegend_flag = c == 0  # one legend entry per trace group is enough

        fig.add_trace(
            go.Scatter(x=t, y=df.iloc[:, c], name="Raw", legendgroup="Raw",
                       line=dict(color="black", width=1.5), showlegend=showlegend_flag),
            row=i + 1, col=j + 1,
        )
        fig.add_trace(
            go.Scatter(x=t, y=df_pred.iloc[:, c], name="Dynamic Filtration", legendgroup="Dynamic Filtration",
                       showlegend=showlegend_flag, line=dict(color="#DC143C", width=1.25)),
            row=i + 1, col=j + 1,
        )

        # Only the bottom row / left column carry axis titles
        xaxis_title = "Time (s)" if i == n_rows - 1 else None
        yaxis_title = "Amplitude (mV)" if j == 0 else None

        fig.update_xaxes(title_text=xaxis_title, row=i + 1, col=j + 1, **axis_style)
        fig.update_yaxes(title_text=yaxis_title, row=i + 1, col=j + 1, **axis_style)

    # The layout only needs to be set once per figure
    fig.update_layout(
        height=1200,
        width=1000,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1, font=dict(size=12)),
        plot_bgcolor="white",
        margin=dict(l=10, r=10, t=100, b=50, autoexpand=True),  # set l and r to center the plot
    )

    html_string = f'''
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <style>
            .plotly-chart-container {{
                display: flex;
                justify-content: center;
                align-items: center;
                height: 100%;
                width: 100%;
            }}
        </style>
    </head>
    <body>
        <div class="plotly-chart-container">
            {fig.to_html(full_html=False, include_plotlyjs='cdn')}
        </div>
    </body>
    </html>
    '''

    # The page declares charset UTF-8, so write it with that encoding explicitly
    with open(f"{output_dir}/Dynamic Filtration_{mvmt_list[mvmt]}.html", "w", encoding="utf-8") as f:
        f.write(html_string)

    fig.show()