### tsdGAN: A generative adversarial network approach for removing electrocardiographic interference from electromyographic signals 
Lucas Haberkamp<sup>1,2</sup> Charles A. Weisenbach<sup>1</sup> Peter Le<sup>3</sup>  
<sup>1</sup>Naval Medical Research Unit Dayton, Wright-Patterson Air Force Base, OH, USA   
<sup>2</sup>Leidos, Reston, VA, USA   
<sup>3</sup>Air Force Research Laboratory, 711th Human Performance Wing, Wright-Patterson Air Force Base, OH, USA 

#### This notebook is used to evaluate the tsdGAN model using the experimental dataset

In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from scipy import signal
import plotly.graph_objects as go
import plotly.subplots as sp

Define a helper function to load the EMG dataset

In [None]:
def extractdf(data_path):
    """Read every ``.csv`` file in a directory into a DataFrame.

    Directory entries are visited in sorted order and only names whose
    extension is exactly ``.csv`` are kept. Each file is parsed with
    ``header=13``, i.e. the 14th line of the file supplies the column names.

    Parameters
    ----------
    data_path : str
        Directory to scan (not searched recursively).

    Returns
    -------
    df_list : list of pandas.DataFrame
        One DataFrame per CSV file, in sorted filename order.
    file_list : list of str
        Matching file names with the extension stripped.
    """
    # Keep only the CSV entries, preserving the sorted directory order
    csv_names = [
        entry
        for entry in sorted(os.listdir(data_path))
        if os.path.splitext(entry)[1] == '.csv'
    ]

    # File stems are used as trial identifiers
    file_list = [os.path.splitext(entry)[0] for entry in csv_names]
    df_list = [
        pd.read_csv(os.path.join(data_path, entry), header=13)
        for entry in csv_names
    ]

    return df_list, file_list

Load experimental validation dataset

In [None]:
# Directory containing the raw validation trials (relative to this notebook's location)
path = '../../Data/Raw TS EMG Data/Validation'

# Load every CSV trial: data_list holds one DataFrame per file and
# file_list holds the matching file names (without extension), both in sorted filename order
data_list, file_list = extractdf(path)

# Show which trials were found so the order can be checked against the movement list below
print("Files in the dataset:", file_list)

Movement List:   
TrunkStability_DS_S20_EMG_Raw_12 - Deep Sagittal Flexion 20lb Lift  
TrunkStability_DS_S20_EMG_Raw_17 - Sit-to-Stand  
TrunkStability_DS_S20_EMG_Raw_47 - Prolonged Seated Task  
TrunkStability_DS_S20_EMG_Raw_56 - Standing Neutral

In [None]:
# Display names for the trials. They are looked up by position in data_list,
# so the order must match the sorted file order returned by extractdf.
mvmt_list = ["Sagittal Flexion", "Sit-to-Stand", 
    "Seated Task", "Standing Neutral"]

# EMG channel columns: everything from the 4th column onward of the first trial
# (presumably the first three columns are non-EMG, e.g. time/metadata -- TODO confirm against the CSV layout)
muscle_list = data_list[0].columns[3:]
# Sampling frequency of the recordings (presumably in Hz -- TODO confirm)
Fs = 1920

# Shared scaler instance; tsdGAN_filter refits it on every window, so it carries no persistent state
scaler = StandardScaler()

# More descriptive names for each EMG sensor, used for subplot titles.
# They are indexed by column position, so the order is assumed to match muscle_list.
better_muscle_names = ['Right Erector Spinae',
    'Left Erector Spinae',
    'Right Internal Oblique',
    'Left Internal Oblique',
    'Right Latissimus Dorsi',
    'Left Latissimus Dorsi',
    'Right Rectus Abdominis',
    'Left Rectus Abdominis',
    'Right External Oblique',
    'Left External Oblique']

Load in tsdGAN model

In [None]:
# Load the saved generator. compile=False skips restoring the training configuration
# (optimizer/loss); this notebook only calls model.predict, so it is not needed.
model = tf.keras.models.load_model("../../Models/Experimental/experimental_generator_epoch150.h5", compile=False)

Define a Butterworth filter function to remove high-frequency components

In [None]:
def butterfilter(x, Fc, Fs, type):
    """Apply a 2nd-order Butterworth filter forwards and backwards (``filtfilt``).

    Parameters
    ----------
    x : array-like
        Signal to filter.
    Fc : float or sequence of float
        Cutoff frequency (or band edges), in the same units as ``Fs``.
    Fs : float
        Sampling frequency.
    type : str
        Filter type forwarded to ``scipy.signal.butter`` (e.g. ``'low'``, ``'high'``).

    Returns
    -------
    numpy.ndarray
        Filtered signal with the same shape as ``x``.
    """
    # scipy expects the cutoff as a fraction of the Nyquist frequency
    nyquist = np.asarray(Fs / 2)
    normalized_cutoff = np.asarray(Fc) / nyquist

    # butter returns the (numerator, denominator) coefficient pair
    coefficients = signal.butter(2, normalized_cutoff, type)
    return signal.filtfilt(*coefficients, x)

Define the tsdGAN filter function that handles all scaling and pre-processing

In [None]:
def tsdGAN_filter(raw_data, window_size=2880):
    """Filter one EMG channel with the tsdGAN generator using overlapping windows.

    The signal is cut into windows of ``window_size`` samples that advance by
    ``window_size // 3`` samples, plus one extra window aligned to the end of
    the signal. Each window is standardized on its own, passed through the
    module-level ``model``, and mapped back to the original scale with the
    statistics of that same window. The predictions are then stitched so that
    every input sample is covered exactly once: the first two thirds of the
    first window, the middle third of each following window, and the tail of
    the end-aligned window.

    Parameters
    ----------
    raw_data : pandas.Series
        One channel of samples; only ``raw_data.values`` is used.
    window_size : int, default 2880
        Number of samples per model input window. Must be at least 3.

    Returns
    -------
    numpy.ndarray
        1-D array with one value per input sample (assuming the model returns
        one value per input sample of each window).

    Raises
    ------
    ValueError
        If ``window_size`` is smaller than 3 or the signal is shorter than a
        single window.
    """
    data = raw_data.values
    n_samples = data.shape[0]

    if window_size < 3:
        raise ValueError("window_size must be at least 3")
    if n_samples < window_size:
        raise ValueError(
            f"Signal has {n_samples} samples but at least {window_size} are required"
        )

    step_size = window_size // 3

    # Regular windows: each one ends at `end` and starts window_size samples earlier
    window_ends = range(window_size, n_samples, step_size)
    x_test = [data[end - window_size:end] for end in window_ends]

    # Extra window aligned to the end of the signal so the trailing samples are covered
    x_test.append(data[-window_size:])
    x_test = np.array(x_test)

    # Standardize every window independently; result has shape (n_windows, window_size, 1)
    x_test_scaled = np.array(
        [scaler.fit_transform(window.reshape(-1, 1)) for window in x_test]
    )

    # Predict on scaled data
    y_pred = model.predict(x_test_scaled, batch_size=32)

    # Undo the scaling with the statistics of the window each prediction came from
    y_pred_inverse = []
    for window, prediction in zip(x_test, y_pred):
        scaler.fit(window.reshape(-1, 1))
        y_pred_inverse.append(scaler.inverse_transform(prediction))
    y_pred_inverse = np.array(y_pred_inverse)

    # Signal is exactly one window long: the single prediction is the whole output
    if len(window_ends) == 0:
        return y_pred_inverse[0].ravel()

    # First window contributes its first two steps (1920 samples for the default window size)
    final_output = [y_pred_inverse[0][:2 * step_size]]

    # Every following regular window contributes only its second step, which
    # continues exactly where the previous contribution stopped
    final_output.extend(seg[step_size:2 * step_size] for seg in y_pred_inverse[1:-1])

    # The end-aligned window supplies whatever is still uncovered
    final_len = n_samples - window_ends[-1] + (window_size - 2 * step_size)
    final_output.append(y_pred_inverse[-1][-final_len:])

    return np.concatenate(final_output, axis=0).ravel()

Process each trial with tsdGAN and create dynamic plots using Plotly

In [None]:
# Directory that receives one interactive HTML figure per movement; created if missing
plot_dir = "../../Plots/Experimental Validation Examples/tsdGAN"
os.makedirs(plot_dir, exist_ok=True)

# Subplot grid layout: one panel per EMG channel (5 x 2 = 10 channels)
n_rows, n_cols = 5, 2

for mvmt, df in enumerate(data_list):
    # Keep only the EMG channels and remove the DC offset. A new frame is built
    # instead of subtracting in place from a column slice of the loaded data.
    emg = df[muscle_list]
    df = emg - emg.mean()
    df = df.apply(lambda x: butterfilter(x, Fc=500, Fs=Fs, type='low')) # Remove unwanted high-frequency components

    df_pred = df.apply(lambda x: tsdGAN_filter(x)) # Apply the tsdGAN filter to every channel

    # Define time-axis from the sampling frequency
    t = np.arange(len(df))/Fs

    fig = sp.make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[f"{mvmt_list[mvmt]}: {better_muscle_names[i]}" for i in range(len(df.columns))],
        vertical_spacing=0.05,
        horizontal_spacing=0.05,  # Reduce the space between subplots horizontally
        specs=[[{}, {}]] * n_rows,
    )

    for i in range(n_rows):
        for j in range(n_cols):
            c = i * n_cols + j  # Channel index shown in this panel (row-major order)

            # Only the first panel contributes legend entries; legendgroup links the rest
            showlegend_flag = (i, j) == (0, 0)

            # Plot the full series (slicing with an end index of -1 would drop the last sample)
            fig.add_trace(go.Scatter(x=t, y=df.iloc[:, c], name="Raw", legendgroup="Raw", line=dict(color="black", width=1.5), showlegend=showlegend_flag), row=i+1, col=j+1)
            fig.add_trace(go.Scatter(x=t, y=df_pred.iloc[:, c], name="tsdGAN", legendgroup="tsdGAN", showlegend=showlegend_flag, line=dict(color="#56B4E9", width=1.25)), row=i+1, col=j+1)

            # Axis titles only on the bottom row / left column to reduce clutter
            xaxis_title = "Time (s)" if i == n_rows - 1 else None
            yaxis_title = "Amplitude (mV)" if j == 0 else None

            fig.update_xaxes(title_text=xaxis_title, row=i+1, col=j+1, showgrid=False, title_standoff=10, tickfont=dict(size=10), title_font=dict(size=12))
            fig.update_yaxes(title_text=yaxis_title, row=i+1, col=j+1, showgrid=False, title_standoff=10, tickfont=dict(size=10), title_font=dict(size=12))

    # Figure-level layout only needs to be applied once per figure
    fig.update_layout(
        height=1200,
        width=1000,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1, font=dict(size=12)),
        plot_bgcolor="white",
        margin=dict(l=10, r=10, t=100, b=50, autoexpand=True),  # set l and r to center the plot
    )

    # Wrap the Plotly div in a minimal page that centers the chart
    html_string = f'''
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <style>
            .plotly-chart-container {{
                display: flex;
                justify-content: center;
                align-items: center;
                height: 100%;
                width: 100%;
            }}
        </style>
    </head>
    <body>
        <div class="plotly-chart-container">
            {fig.to_html(full_html=False, include_plotlyjs='cdn')}
        </div>
    </body>
    </html>
    '''

    # Write as UTF-8 to match the charset declared in the page, independent of the platform default
    with open(f"{plot_dir}/tsdGAN_{mvmt_list[mvmt]}.html", "w", encoding="utf-8") as f:
        f.write(html_string)
    fig.show()