### ECGbe-GAN: A novel deep learning approach for eliminating ECG interference from EMG data
Lucas Haberkamp<sup>1,2,3</sup> Charles A. Weisenbach<sup>1,2</sup> Peter Le<sup>4</sup>  
<sup>1</sup>Naval Medical Research Unit Dayton, Wright-Patterson Air Force Base, OH, USA   
<sup>2</sup>Oak Ridge Institute for Science and Education, Oak Ridge, TN, USA   
<sup>3</sup>Leidos, Reston, VA, USA   
<sup>4</sup>Air Force Research Laboratory, 711th Human Performance Wing, Wright-Patterson Air Force Base, OH, USA 

#### This notebook evaluates a zero-phase, effectively 4th-order Butterworth high-pass filter with a 30 Hz cutoff (a 2nd-order design applied forward and backward with `filtfilt`) on the experimental dataset

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
import plotly.graph_objects as go
import plotly.subplots as sp

ModuleNotFoundError: No module named 'dash'

Define a helper function to load the EMG dataset

In [9]:
def extractdf(data_path, header=13):
    """Load every ``.csv`` file in a directory into a DataFrame.

    Files are visited in sorted filename order, so the two returned lists are
    aligned with each other and the order is deterministic across platforms.

    Parameters
    ----------
    data_path : str or path-like
        Directory containing the exported EMG ``.csv`` files.
    header : int, default 13
        Row number passed to ``pandas.read_csv(header=...)``. It holds the column
        names; rows before it are discarded as preamble (pandas does not count
        blank lines). NOTE(review): 13 matches the current export format —
        confirm if the acquisition software changes.

    Returns
    -------
    df_list : list of pandas.DataFrame
        One DataFrame per ``.csv`` file.
    file_list : list of str
        The file names without extension, in the same order as ``df_list``.
    """
    df_list, file_list = [], []
    for filename in sorted(os.listdir(data_path)):
        stem, ext = os.path.splitext(filename)
        full_path = os.path.join(data_path, filename)
        # Only regular .csv files; skip e.g. subdirectories whose name ends in .csv
        if ext != '.csv' or not os.path.isfile(full_path):
            continue
        file_list.append(stem)  # file identifier (name without extension)
        df_list.append(pd.read_csv(full_path, header=header))

    return df_list, file_list

Load the experimental validation dataset

In [10]:
# Directory holding the raw EMG exports for the experimental validation trials
path = '../../Data/Raw TS EMG Data/Validation'

# Load every .csv trial into a DataFrame; file_list holds the matching file
# names (without extension), in the same sorted order as data_list
data_list, file_list = extractdf(path)

# The next cell reads data_list[0] and pairs trials with movement labels by
# position, so fail here with a clear message rather than an IndexError later.
if not data_list:
    raise FileNotFoundError(f"No .csv files found in {path!r}")

print("Files in the dataset:", file_list)

Files in the dataset: ['TrunkStability_DS_S20_EMG_Raw_12', 'TrunkStability_DS_S20_EMG_Raw_17', 'TrunkStability_DS_S20_EMG_Raw_47', 'TrunkStability_DS_S20_EMG_Raw_56']


Movement List:   
TrunkStability_DS_S20_EMG_Raw_12 - Deep Sagittal Flexion 20lb Lift  
TrunkStability_DS_S20_EMG_Raw_17 - Sit-to-Stand  
TrunkStability_DS_S20_EMG_Raw_47 - Prolonged Seated Task  
TrunkStability_DS_S20_EMG_Raw_56 - Standing Neutral

In [11]:
# Short movement labels used in plot titles and output filenames. They are
# paired with data_list by position, so the order must match the sorted file
# order (see the Movement List above).
mvmt_list = ["Sagittal Flexion", "Sit-to-Stand", 
    "Seated Task", "Standing Neutral"]

# Define analysis constants
# EMG channel columns: everything after the first three columns.
# NOTE(review): assumes the first three columns are non-EMG (e.g. time/index) — confirm against the export.
muscle_list = data_list[0].columns[3:]
Fs = 1920  # sampling frequency in Hz, as used by the Butterworth filtering

# Define a new set of more descriptive names for each EMG sensor
# (paired with muscle_list by position)
better_muscle_names = ['Right Erector Spinae',
    'Left Erector Spinae',
    'Right Internal Oblique',
    'Left Internal Oblique',
    'Right Latissimus Dorsi',
    'Left Latissimus Dorsi',
    'Right Rectus Abdominis',
    'Left Rectus Abdominis',
    'Right External Oblique',
    'Left External Oblique']

# Labels are matched to data purely by position; catch a mismatch here instead
# of producing mislabelled plots/files or an IndexError mid-loop.
if len(mvmt_list) != len(data_list):
    raise ValueError(
        f"{len(data_list)} trials loaded but {len(mvmt_list)} movement labels defined")
if len(better_muscle_names) != len(muscle_list):
    raise ValueError(
        f"{len(muscle_list)} EMG channels found but {len(better_muscle_names)} muscle names defined")

Define a Butterworth filter function

In [12]:
# Function to Butterworth-filter the data
def butterfilter(x, Fc, Fs, type, order=2):
    """Apply a zero-phase Butterworth filter to a signal.

    The filter is designed with ``scipy.signal.butter`` and applied with
    ``scipy.signal.filtfilt`` (forward, then backward). This removes phase lag
    and squares the magnitude response, so the default ``order=2`` design gives
    an effective 4th-order roll-off.

    Parameters
    ----------
    x : array_like
        Signal to filter (filtered along its last axis, filtfilt's default).
    Fc : float or sequence of two floats
        Cutoff frequency in Hz; a ``[low, high]`` pair for ``'bandpass'`` or
        ``'bandstop'``.
    Fs : float
        Sampling frequency in Hz.
    type : {'low', 'high', 'bandpass', 'bandstop'}
        Filter type, forwarded to ``signal.butter`` as ``btype``. The name
        shadows the builtin but is kept because callers pass it by keyword.
    order : int, default 2
        Order of the designed filter before forward-backward application.

    Returns
    -------
    numpy.ndarray
        The filtered signal, same shape as ``x``.
    """
    # butter() expects cutoffs normalized to the Nyquist frequency (Fs / 2)
    Wn = np.asarray(Fc) / (Fs / 2)
    b, a = signal.butter(order, Wn, btype=type)
    return signal.filtfilt(b, a, x)

Process each trial with the high-pass filter and create interactive plots using Plotly

In [15]:
# Output locations for the processed signals and (optionally) the HTML plots
pred_dir = "../../Data/Experimental Predictions"
plot_dir = "../../Plots/Experimental Validation Examples/High Pass Filter"
save_html = False  # set True to also export each figure as a standalone, centered HTML page

n_rows, n_cols = 5, 2  # subplot grid: one panel per EMG channel, filled row-major

os.makedirs(pred_dir, exist_ok=True)

for mvmt, df in enumerate(data_list):
    # Keep only the EMG channels. Build new frames rather than modifying the
    # column selection in place (avoids pandas chained-assignment warnings).
    df = df[muscle_list]
    df = df - df.mean()  # Ensure data has zero mean
    df = df.apply(lambda x: butterfilter(x, Fc=500, Fs=Fs, type='low'))  # Remove unwanted high-frequency components

    df_pred = df.apply(lambda x: butterfilter(x, Fc=30, Fs=Fs, type='high'))  # Apply the Butterworth high pass filter

    # save raw (pre-processed) data and the high-pass output to .csv files
    df.to_csv(f"{pred_dir}/RAW_{mvmt_list[mvmt]}.csv", index=False)
    df_pred.to_csv(f"{pred_dir}/HPF_{mvmt_list[mvmt]}.csv", index=False)

    # Define time-axis (seconds)
    t = np.arange(len(df)) / Fs

    # Plot window. end_idx=None (not -1) so the slice includes the final sample.
    start_idx, end_idx = 0, None
    t_plot = t[start_idx:end_idx]

    fig = sp.make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[f"HPF: {mvmt_list[mvmt]} - {better_muscle_names[i]}" for i in range(df.shape[1])],
        vertical_spacing=0.05,
        horizontal_spacing=0.05,  # reduce the space between subplots horizontally
        specs=[[{}] * n_cols for _ in range(n_rows)],
    )

    for i in range(n_rows):
        for j in range(n_cols):
            c = i * n_cols + j  # channel index for panel (i, j), row-major
            first_panel = (i, j) == (0, 0)  # show each legend entry only once
            fig.add_trace(go.Scatter(x=t_plot, y=df.iloc[start_idx:end_idx, c], name="Contaminated EMG", legendgroup="Contaminated EMG", line=dict(color="black", width=2), showlegend=first_panel), row=i+1, col=j+1)
            fig.add_trace(go.Scatter(x=t_plot, y=df_pred.iloc[start_idx:end_idx, c], name="High Pass Filter", legendgroup="High Pass Filter", showlegend=first_panel, line=dict(color="red", width=2)), row=i+1, col=j+1)

            # Axis titles only on the bottom row (x) and left column (y)
            axis_style = dict(row=i+1, col=j+1, showgrid=True, gridwidth=1, gridcolor="LightGrey", title_standoff=10, tickfont=dict(size=10), title_font=dict(size=12))
            fig.update_xaxes(title_text="Time (s)" if i == n_rows - 1 else None, **axis_style)
            fig.update_yaxes(title_text="Amplitude (mV)" if j == 0 else None, **axis_style)

    # Figure-wide layout only needs to be applied once, after all panels exist
    fig.update_layout(
        height=1200,
        width=1000,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1, font=dict(size=12)),
        plot_bgcolor="white",
        margin=dict(l=10, r=10, t=100, b=50, autoexpand=True),  # set l and r to center the plot
    )

    fig.show()

    if save_html:
        os.makedirs(plot_dir, exist_ok=True)
        html_string = f'''
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="UTF-8">
            <meta name="viewport" content="width=device-width, initial-scale=1">
            <style>
                .plotly-chart-container {{
                    display: flex;
                    justify-content: center;
                    align-items: center;
                    height: 100%;
                    width: 100%;
                }}
            </style>
        </head>
        <body>
            <div class="plotly-chart-container">
                {fig.to_html(full_html=False, include_plotlyjs='cdn')}
            </div>
        </body>
        </html>
        '''
        with open(f"{plot_dir}/HPF_{mvmt_list[mvmt]}.html", "w", encoding="utf-8") as fh:
            fh.write(html_string)