In [None]:
# Installation of BioSignalsNotebooks
# %pip install biosignalsnotebooks
# %pip install tqdm

In [None]:
# Imports
import os
import pandas as pd
import plotly.graph_objects as go
import biosignalsnotebooks as bsnb
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm 

In [None]:
def subject_df_creator(subject_id, data_root="5362627"):
    """Load every available raw circuit CSV of one subject into a single dataframe.

    Parameters
    ----------
    subject_id : str
        Numeric part of the subject label, e.g. ``"156"`` for subject ``AB156``.
    data_root : str, optional
        Directory that contains the per-subject ``AB<id>`` folders.

    Returns
    -------
    pandas.DataFrame
        Columns 30..43 of each circuit file (the 14 channels kept by this
        notebook) plus a ``Subject`` column, concatenated over circuits 1..50.

    Raises
    ------
    ValueError
        If none of the 50 circuit files exists for this subject.
    """
    subject_label = f"AB{subject_id}"
    circuit_frames = []
    missing_circuits = []
    for circuit_number in tqdm(range(1, 51), desc=f"Concatenating Circuit Files for Subject {subject_label}"):
        circuit_code = f"0{circuit_number:02d}"
        filename = f"{data_root}/{subject_label}/{subject_label}/Raw/{subject_label}_Circuit_{circuit_code}_raw.csv"
        if not os.path.exists(filename):
            missing_circuits.append(circuit_code)
            continue
        df_circuit = pd.read_csv(filename)
        # Keep only columns 30..43; the columns before and after are not needed here.
        # .copy() so adding 'Subject' below does not write into a view of the raw frame.
        df_circuit = df_circuit.iloc[:, 30:44].copy()
        df_circuit['Subject'] = subject_label
        circuit_frames.append(df_circuit)

    # Report missing files before concatenating so the report is visible even
    # when nothing could be loaded.
    if missing_circuits:
        print(f"{len(missing_circuits)} Files do not exist:", missing_circuits)
    if not circuit_frames:
        raise ValueError(
            f"No circuit files found for subject {subject_label} under '{data_root}'"
        )
    # Concatenate all circuit frames along the rows axis
    return pd.concat(circuit_frames, ignore_index=True)

In [None]:
# Creating one dataframe per subject; subjects are loaded in the order listed
(
    df_subject_156,
    df_subject_185,
    df_subject_186,
    df_subject_188,
    df_subject_189,
    df_subject_190,
    df_subject_191,
    df_subject_192,
    df_subject_193,
    df_subject_194,
) = (
    subject_df_creator(subject_id)
    for subject_id in ("156", "185", "186", "188", "189", "190", "191", "192", "193", "194")
)

In [None]:
# Stack the per-subject dataframes into one frame with a fresh 0..N-1 index
subject_frames = [
    df_subject_156, df_subject_185, df_subject_186, df_subject_188, df_subject_189,
    df_subject_190, df_subject_191, df_subject_192, df_subject_193, df_subject_194,
]
df_all_subjects = pd.concat(subject_frames, ignore_index=True)

In [None]:
# Show the first ten rows of the merged frame
df_all_subjects.head(10)

# Visualisation 

In [None]:
# Plot rows 60000..89999 of the right TA channel.
# The time axis is built from the SAME slice as the signal; using the full index
# for x would draw the excerpt starting at 0 s instead of at its real position.
# Dividing by 1000 converts row number to seconds (1000 rows per second assumed,
# matching the sample rate used for burst detection in this notebook).
right_ta_excerpt = df_all_subjects['Right_TA'].iloc[60000:90000]
fig1 = go.Figure()
fig1.add_trace(go.Scatter(x=right_ta_excerpt.index/1000, y=right_ta_excerpt))
fig1.update_layout( title="sEMG Signal: Sitting Vs Contraction Bursts Vs Rest", xaxis_title="Time (s)",
                    yaxis_title="sEMG Activity (V)", margin=dict(l=50, r=50, b=50, t=50, pad=4),
                    autosize=False, width=800, height=300)
# plotting
fig1.show()

## Analysis

In [None]:
# Mean, standard deviation and variance of each of the 14 muscle channels.
# Statistics are taken down the rows (axis=0) so there is one value per muscle;
# axis=1 would produce one value per time sample instead.
muscle_signals = df_all_subjects.iloc[:, :-1]  # every column except the trailing 'Subject'
df_analysis = pd.DataFrame({
    'Muscles': muscle_signals.columns,
    # .to_numpy() drops the column-name index so values are assigned positionally
    'Mean': muscle_signals.mean(axis=0).to_numpy(),
    'Std': muscle_signals.std(axis=0).to_numpy(),
    'Var': muscle_signals.var(axis=0).to_numpy(),
})
df_analysis

In [None]:
# Detect activation bursts for every muscle channel with one shared parameter set
sr = 1000 # sample rate = 1000Hz
sl = 20 # smooth level (size of the moving-average window; used to be 40)
th = 10 # threshold level (to cover activation)


def _detect_bursts(column_name):
    """Run bsnb.detect_emg_activations on one column of df_all_subjects with the shared settings."""
    return bsnb.detect_emg_activations(
        emg_signal=df_all_subjects[column_name],
        sample_rate=sr,
        smooth_level=sl,
        threshold_level=th,
        time_units=True,
        device='CH0',
        plot_result=False,
    )


## TA
detected_bursts_right_TA = _detect_bursts('Right_TA')
detected_bursts_left_TA = _detect_bursts('Left_TA')
## MG
detected_bursts_right_MG = _detect_bursts('Right_MG')
detected_bursts_left_MG = _detect_bursts('Left_MG')
## SOL
detected_bursts_right_SOL = _detect_bursts('Right_SOL')
detected_bursts_left_SOL = _detect_bursts('Left_SOL')
## BF
detected_bursts_right_BF = _detect_bursts('Right_BF')
detected_bursts_left_BF = _detect_bursts('Left_BF')
## ST
detected_bursts_right_ST = _detect_bursts('Right_ST')
detected_bursts_left_ST = _detect_bursts('Left_ST')
## VL
detected_bursts_right_VL = _detect_bursts('Right_VL')
detected_bursts_left_VL = _detect_bursts('Left_VL')
## RF
detected_bursts_right_RF = _detect_bursts('Right_RF')
detected_bursts_left_RF = _detect_bursts('Left_RF')

In [None]:
# Burst detection plot for the start of the right Tibialis Anterior recording
plot_duration = 40000 # number of leading rows to plot (equals milliseconds at 1000 Hz)
right_ta_preview = df_all_subjects['Right_TA'][:plot_duration]
bsnb.detect_emg_activations(
    emg_signal=right_ta_preview,
    sample_rate=sr,
    smooth_level=sl,
    threshold_level=th,
    time_units=True,
    device='CH0',
    plot_result=True,
)
print('')

In [None]:
# Preview the first entries of the first three detection outputs, one column each.
# The truncation is applied to every output individually; slicing the outer
# tuple a second time would not shorten the arrays at all.
# NOTE(review): the outputs are presumably onsets, offsets and the smoothed signal -
# confirm against the biosignalsnotebooks documentation.
n_preview = 10
pd.DataFrame([list(part[:n_preview]) for part in detected_bursts_right_TA[0:3]]).transpose()

In [None]:
# First activation(s) of the right TA: detected burst borders (red) against the
# window that is later captured around the onset (black)
duration = 8000
shift = 2000
number_bursts_to_plot = 1

plt.rcParams["figure.figsize"] = (10,5)
fig = plt.figure()

plt.plot(df_all_subjects['Right_TA'][:duration], color="cornflowerblue")
for burst_idx in range(number_bursts_to_plot): # first N bursts
    onset_ms = detected_bursts_right_TA[0][burst_idx]*1000
    plt.axvline(onset_ms, color='red', label="Detected Burst Region") # detected onset
    offset_ms = detected_bursts_right_TA[1][burst_idx]*1000
    plt.axvline(offset_ms, color='red') # detected offset
    plt.axvline(onset_ms+400, color='black', label="Onset Window (500ms)") # window end: onset + 400
    plt.axvline(onset_ms-100, color='black') # window start: onset - 100

plt.legend(loc="upper left")
plt.xlim(shift, duration)
plt.grid()
plt.xlabel('Time (ms)', fontsize=10)
plt.ylabel('sEMG Intensity (V)', fontsize=10)

# plt.savefig("Window.png")

In [None]:
# Left Tibialis Anterior: first 10000 rows with the first detected burst shaded
# and the onset window marked.
# Signal, shaded burst and onset window all come from the LEFT TA so the overlay
# matches the plotted trace; time and signal are sliced together.
left_ta_excerpt = df_all_subjects['Left_TA'].iloc[:10000]
first_onset = detected_bursts_left_TA[0][0]   # seconds (detection ran with time_units=True)
first_offset = detected_bursts_left_TA[1][0]

fig1 = go.Figure()
fig1.add_trace(go.Scatter(x=left_ta_excerpt.index/1000, y=left_ta_excerpt))

# formatting the plot
fig1.update_layout(title="sEMG Signal: Detected burst and corrected onset window",
                   xaxis_title="Time (s)", yaxis_title="sEMG Activity (V)",
                   margin=dict(l=50, r=50, b=50, t=50, pad=4),
                   autosize=False, width=800, height=300)

fig1.add_vrect(x0=first_onset, x1=first_offset, row="all", col=1,
               annotation_text="Detected Burst", annotation_position="top right", fillcolor="gray",
               opacity=0.25, line_width=0)

# Onset window: from 0.1 s before to 0.4 s after the detected onset
fig1.add_vline(x=first_onset+0.4, line_width=1.5, line_dash="dot", line_color="red")
fig1.add_vline(x=first_onset-0.1, line_width=1.5, line_dash="dot", line_color="red",
               annotation_text="Onset Window", annotation_position="bottom right")

# plotting
fig1.show()

In [None]:
# Right Tibialis Anterior: whole recording with every detected burst shaded
fig1 = go.Figure()
right_ta_trace = go.Scatter(x=df_all_subjects.index/1000, y=df_all_subjects['Right_TA'])
fig1.add_trace(right_ta_trace)
# formatting the plot
fig1.update_layout(
    autosize=True,
    title="sEMG Signal: Detection of Activation Bursts",
    xaxis_title="Time (s)",
    yaxis_title="sEMG Activity (V)",
    margin=dict(l=50, r=50, b=50, t=50, pad=4),
)

# one shaded rectangle per (onset, offset) pair
for burst_idx, burst_onset in enumerate(detected_bursts_right_TA[0]):
    fig1.add_vrect(
        x0=burst_onset,
        x1=detected_bursts_right_TA[1][burst_idx],
        row="all",
        col=1,
        annotation_text="Detected Burst",
        annotation_position="top right",
        fillcolor="black",
        opacity=0.25,
        line_width=0,
    )

# fig1.update_xaxes(range=[30, 60])
fig1.update_layout(autosize=False, width=800, height=300)
# plotting
fig1.show()

In [ ]:
# Left ST channel: whole recording with every detected burst shaded
fig1 = go.Figure()
left_st_trace = go.Scatter(x=df_all_subjects.index/1000, y=df_all_subjects['Left_ST'])
fig1.add_trace(left_st_trace)
# formatting the plot
fig1.update_layout(
    autosize=True,
    title="sEMG Signal: Detection of Activation Bursts",
    xaxis_title="Time (s)",
    yaxis_title="sEMG Activity (V)",
    margin=dict(l=50, r=50, b=50, t=50, pad=4),
)

# one shaded rectangle per (onset, offset) pair
for burst_idx, burst_onset in enumerate(detected_bursts_left_ST[0]):
    fig1.add_vrect(
        x0=burst_onset,
        x1=detected_bursts_left_ST[1][burst_idx],
        row="all",
        col=1,
        annotation_text="Detected Burst",
        annotation_position="top right",
        fillcolor="black",
        opacity=0.25,
        line_width=0,
    )

# fig1.update_xaxes(range=[30, 60])
fig1.update_layout(autosize=False, width=800, height=300)
# plotting
fig1.show()

In [None]:
# Total number of bursts per muscle = number of detected onsets
## TA
tot_bursts_right_TA = len(detected_bursts_right_TA[0])
tot_bursts_left_TA = len(detected_bursts_left_TA[0])
## MG
tot_bursts_right_MG = len(detected_bursts_right_MG[0])
tot_bursts_left_MG = len(detected_bursts_left_MG[0])
## SOL
tot_bursts_right_SOL = len(detected_bursts_right_SOL[0])
tot_bursts_left_SOL = len(detected_bursts_left_SOL[0])
## BF
tot_bursts_right_BF = len(detected_bursts_right_BF[0])
tot_bursts_left_BF = len(detected_bursts_left_BF[0])
## ST
tot_bursts_right_ST = len(detected_bursts_right_ST[0])
tot_bursts_left_ST = len(detected_bursts_left_ST[0])
## VL
tot_bursts_right_VL = len(detected_bursts_right_VL[0])
tot_bursts_left_VL = len(detected_bursts_left_VL[0])
## RF
tot_bursts_right_RF = len(detected_bursts_right_RF[0])
tot_bursts_left_RF = len(detected_bursts_left_RF[0])

In [None]:
# Report the burst totals, one line per muscle with right and left side by side
burst_total_rows = [
    ("TA", tot_bursts_right_TA, tot_bursts_left_TA),
    ("MG", tot_bursts_right_MG, tot_bursts_left_MG),
    ("SOL", tot_bursts_right_SOL, tot_bursts_left_SOL),
    ("BF", tot_bursts_right_BF, tot_bursts_left_BF),
    ("ST", tot_bursts_right_ST, tot_bursts_left_ST),
    ("VL", tot_bursts_right_VL, tot_bursts_left_VL),
    ("RF", tot_bursts_right_RF, tot_bursts_left_RF),
]
burst_total_report = ["- Number of Identified Bursts:\n"]
for muscle_code, right_total, left_total in burst_total_rows:
    burst_total_report += [
        f"   Right {muscle_code}:\t", right_total, "\t",
        f"   Left {muscle_code}:\t", left_total, "\n",
    ]
print(*burst_total_report)

In [None]:
# Average burst length per muscle, in milliseconds
def _average_burst_length_ms(bursts):
    """Mean of (offset - onset) over all bursts, scaled from seconds to milliseconds."""
    onsets = np.array(bursts[0])
    offsets = np.array(bursts[1])
    return np.mean(offsets-onsets)*1000


## TA
average_burst_length_right_TA = _average_burst_length_ms(detected_bursts_right_TA)
average_burst_length_left_TA = _average_burst_length_ms(detected_bursts_left_TA)
## MG
average_burst_length_right_MG = _average_burst_length_ms(detected_bursts_right_MG)
average_burst_length_left_MG = _average_burst_length_ms(detected_bursts_left_MG)
## SOL
average_burst_length_right_SOL = _average_burst_length_ms(detected_bursts_right_SOL)
average_burst_length_left_SOL = _average_burst_length_ms(detected_bursts_left_SOL)
## BF
average_burst_length_right_BF = _average_burst_length_ms(detected_bursts_right_BF)
average_burst_length_left_BF = _average_burst_length_ms(detected_bursts_left_BF)
## ST
average_burst_length_right_ST = _average_burst_length_ms(detected_bursts_right_ST)
average_burst_length_left_ST = _average_burst_length_ms(detected_bursts_left_ST)
## VL
average_burst_length_right_VL = _average_burst_length_ms(detected_bursts_right_VL)
average_burst_length_left_VL = _average_burst_length_ms(detected_bursts_left_VL)
## RF
average_burst_length_right_RF = _average_burst_length_ms(detected_bursts_right_RF)
average_burst_length_left_RF = _average_burst_length_ms(detected_bursts_left_RF)

In [None]:
# Report the average burst lengths (rounded to 2 decimals), right and left side by side
average_length_rows = [
    ("TA", average_burst_length_right_TA, average_burst_length_left_TA),
    ("MG", average_burst_length_right_MG, average_burst_length_left_MG),
    ("SOL", average_burst_length_right_SOL, average_burst_length_left_SOL),
    ("BF", average_burst_length_right_BF, average_burst_length_left_BF),
    ("ST", average_burst_length_right_ST, average_burst_length_left_ST),
    ("VL", average_burst_length_right_VL, average_burst_length_left_VL),
    ("RF", average_burst_length_right_RF, average_burst_length_left_RF),
]
average_length_report = ["- Avg. Length:\n"]
for muscle_code, right_length, left_length in average_length_rows:
    average_length_report += [
        f"   Right {muscle_code}:\t", round(right_length, 2), "ms", "\t",
        f"   Left {muscle_code}:\t", round(left_length, 2), "ms", "\n",
    ]
print(*average_length_report)

In [None]:
# Display names and detection results in matching order: burst_data[k] belongs to muscles[k]
muscles = [
    'Right Tibialis Ant.', 'Left Tibialis Ant.',
    'Right MG', 'Left MG',
    'Right SOL', 'Left SOL',
    'Right BF', 'Left BF',
    'Right ST', 'Left ST',
    'Right VL', 'Left VL',
    'Right RF', 'Left RF',
]
burst_data = [
    detected_bursts_right_TA,
    detected_bursts_left_TA,
    detected_bursts_right_MG,
    detected_bursts_left_MG,
    detected_bursts_right_SOL,
    detected_bursts_left_SOL,
    detected_bursts_right_BF,
    detected_bursts_left_BF,
    detected_bursts_right_ST,
    detected_bursts_left_ST,
    detected_bursts_right_VL,
    detected_bursts_left_VL,
    detected_bursts_right_RF,
    detected_bursts_left_RF,
]

In [None]:
# Per-subject burst statistics.
# Detection ran on the concatenated signal of all subjects, so each detection
# result is one flat (onsets, offsets, ...) tuple - it cannot be indexed by
# subject. Every burst is therefore assigned to the subject that owns the row
# at which the burst starts.

# List of unique subjects (in order of appearance)
subjects = list(df_all_subjects['Subject'].unique())
_subject_of_row = df_all_subjects['Subject'].to_numpy()


def _per_subject_burst_stats(bursts):
    """Return (total burst count, [mean burst length in ms for each entry of `subjects`]).

    `bursts[0]` / `bursts[1]` are the onset / offset times in seconds of the
    concatenated recording. A subject without bursts contributes NaN.
    """
    onsets = np.asarray(bursts[0], dtype=float)
    offsets = np.asarray(bursts[1], dtype=float)
    # seconds -> row number of df_all_subjects (sr rows per second); clipped to stay in range
    onset_rows = np.clip(np.rint(onsets * sr).astype(int), 0, len(_subject_of_row) - 1)
    burst_subjects = _subject_of_row[onset_rows]
    durations_ms = (offsets - onsets) * 1000

    total = 0
    averages = []
    for subject in subjects:
        in_subject = burst_subjects == subject
        total += int(in_subject.sum())
        averages.append(durations_ms[in_subject].mean() if in_subject.any() else float('nan'))
    return total, averages


## TA
tot_bursts_right_TA, average_burst_length_right_TA = _per_subject_burst_stats(detected_bursts_right_TA)
tot_bursts_left_TA, average_burst_length_left_TA = _per_subject_burst_stats(detected_bursts_left_TA)
## MG
tot_bursts_right_MG, average_burst_length_right_MG = _per_subject_burst_stats(detected_bursts_right_MG)
tot_bursts_left_MG, average_burst_length_left_MG = _per_subject_burst_stats(detected_bursts_left_MG)
## SOL
tot_bursts_right_SOL, average_burst_length_right_SOL = _per_subject_burst_stats(detected_bursts_right_SOL)
tot_bursts_left_SOL, average_burst_length_left_SOL = _per_subject_burst_stats(detected_bursts_left_SOL)
## BF
tot_bursts_right_BF, average_burst_length_right_BF = _per_subject_burst_stats(detected_bursts_right_BF)
tot_bursts_left_BF, average_burst_length_left_BF = _per_subject_burst_stats(detected_bursts_left_BF)
## ST
tot_bursts_right_ST, average_burst_length_right_ST = _per_subject_burst_stats(detected_bursts_right_ST)
tot_bursts_left_ST, average_burst_length_left_ST = _per_subject_burst_stats(detected_bursts_left_ST)
## VL
tot_bursts_right_VL, average_burst_length_right_VL = _per_subject_burst_stats(detected_bursts_right_VL)
tot_bursts_left_VL, average_burst_length_left_VL = _per_subject_burst_stats(detected_bursts_left_VL)
## RF
tot_bursts_right_RF, average_burst_length_right_RF = _per_subject_burst_stats(detected_bursts_right_RF)
tot_bursts_left_RF, average_burst_length_left_RF = _per_subject_burst_stats(detected_bursts_left_RF)


In [None]:
# Number of entries in the first detection output of the right TA
right_ta_first_output = detected_bursts_right_TA[0]
len(right_ta_first_output)

In [None]:
# Histogram of burst durations: one panel per muscle, one overlaid distribution per subject
f,a = plt.subplots(7,2)
f.set_size_inches(13,30)
a = a.ravel()

# Colors for each subject
colors = plt.cm.viridis(np.linspace(0, 1, len(subjects)))

# Subject label of every row of the concatenated recording
subject_of_row = df_all_subjects['Subject'].to_numpy()

for i in range(len(muscles)):
    onsets = np.asarray(burst_data[i][0], dtype=float)
    offsets = np.asarray(burst_data[i][1], dtype=float)
    durations = offsets - onsets
    # Assign each burst to the subject owning the row where it starts
    # (onsets are in seconds; sr rows per second), so each subject's histogram
    # only contains that subject's bursts.
    onset_rows = np.clip(np.rint(onsets * sr).astype(int), 0, len(subject_of_row) - 1)
    burst_subjects = subject_of_row[onset_rows]

    drew_any = False
    for j, subject in enumerate(subjects):
        subject_durations = durations[burst_subjects == subject]
        if subject_durations.size == 0:
            continue  # nothing to draw for this subject on this muscle
        a[i].hist(subject_durations, bins=40, alpha=0.5, label=subject, color=colors[j])
        drew_any = True

    if drew_any:
        a[i].legend(loc='upper right')
    a[i].set_xlabel("Burst Duration (seconds)")
    a[i].set_ylabel("Occurrences")
    a[i].set_title('Histogram Burst Duration: ' + muscles[i])

plt.tight_layout()
plt.show()

In [None]:
# Position of the subject to hold out (None keeps every subject; 0 is the first subject, 1 the second, ...)
leave_one_out = 7

if leave_one_out is not None:
    # Each pop REMOVES the held-out entry from the source list, so the source
    # lists afterwards only contain the remaining subjects. Running this twice
    # removes a second entry.
    (
        loo_activation_burst_left_arm,
        loo_activation_burst_right_arm,
        loo_activation_burst_left_leg,
        loo_activation_burst_right_leg,
        loo_emg_series_left_arm,
        loo_emg_series_right_arm,
        loo_emg_series_left_leg,
        loo_emg_series_right_leg,
    ) = (
        [per_subject_list.pop(leave_one_out)]
        for per_subject_list in (
            activation_burst_left_arm,
            activation_burst_right_arm,
            activation_burst_left_leg,
            activation_burst_right_leg,
            emg_series_left_arm,
            emg_series_right_arm,
            emg_series_left_leg,
            emg_series_right_leg,
        )
    )
    print("{} was excluded from the training!".format(subjects[leave_one_out]))


# Extracting Bursts
### Method: Preserving Onset and Window = 500ms
def extract_burst_windows(emg_muscle, onset_list, window, left_shift, channels,
                          muscle):  #CAREFUL HERE WITH NUMBER OF TOTAL SUBJECTS. DO NOT REPEAT SAME SUBJECT
    """Cut a fixed-length window around every detected onset.

    Parameters
    ----------
    emg_muscle : sequence of 2-D arrays
        One recording per subject, indexed as ``[row, column]``; column 0 is
        skipped and columns ``1..channels`` are extracted.
    onset_list : sequence
        ``onset_list[subject][0]`` holds that subject's onset times in seconds.
    window : int
        Number of rows per extracted window.
    left_shift : int
        Number of rows the window starts before the onset.
    channels : int
        Number of channel columns to extract.
    muscle : str
        Label shown in the progress bar.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_bursts, window, channels)``. Onsets whose window would start
        before the first row or end after the last row are skipped, because a
        negative start would otherwise wrap around to the end of the recording.

    Raises
    ------
    IndexError
        If a recording has fewer than ``channels + 1`` columns.
    """
    burst_samples = []
    for subject in tqdm(range(len(emg_muscle)), desc="{} Extraction Burst Process (Fixed Window)".format(muscle),
                        position=0, leave=True):
        recording = np.asarray(emg_muscle[subject])
        if recording.shape[1] < channels + 1:
            raise IndexError(
                "recording {} has {} columns, {} required".format(subject, recording.shape[1], channels + 1))
        for onset in onset_list[subject][0]:
            # round, not truncate: seconds * 1000 can land just below the intended row
            start = int(round(onset * 1000)) - left_shift
            stop = start + window
            if start < 0 or stop > recording.shape[0]:
                continue  # window does not fit inside the recording
            burst_samples.append(recording[start:stop, 1:channels + 1])
    if not burst_samples:
        return np.empty((0, window, channels))
    return np.stack(burst_samples)


# Fixed-window extraction settings
window = 500  # total window size, in ms
left_shift = 100  # window starts this many ms before the detected onset
channels = 16

# Windows of the subjects kept for training
fixed_bursts_left_arm = extract_burst_windows(
    emg_series_left_arm, activation_burst_left_arm, window, left_shift, channels, 'Left Arm')
fixed_bursts_right_arm = extract_burst_windows(
    emg_series_right_arm, activation_burst_right_arm, window, left_shift, channels, 'Right Arm')
fixed_bursts_left_leg = extract_burst_windows(
    emg_series_left_leg, activation_burst_left_leg, window, left_shift, channels, 'Left Leg')
fixed_bursts_right_leg = extract_burst_windows(
    emg_series_right_leg, activation_burst_right_leg, window, left_shift, channels, 'Right Leg')

# Windows of the held-out subject
if leave_one_out is not None:
    loo_fixed_bursts_left_arm = extract_burst_windows(
        loo_emg_series_left_arm, loo_activation_burst_left_arm, window, left_shift, channels,
        'Leave one Out: Left Arm')
    loo_fixed_bursts_right_arm = extract_burst_windows(
        loo_emg_series_right_arm, loo_activation_burst_right_arm, window, left_shift, channels,
        'Leave one Out: Right Arm')
    loo_fixed_bursts_left_leg = extract_burst_windows(
        loo_emg_series_left_leg, loo_activation_burst_left_leg, window, left_shift, channels,
        'Leave one Out: Left Leg')
    loo_fixed_bursts_right_leg = extract_burst_windows(
        loo_emg_series_right_leg, loo_activation_burst_right_leg, window, left_shift, channels,
        'Leave one Out: Right Leg')

# Shape report
training_shape_report = [
    "Shape of Muscle Bursts (Fixed Window):\n",
    "   Left Arm:\t", fixed_bursts_left_arm.shape, "\n",
    "   Right Arm:\t", fixed_bursts_right_arm.shape, "\n",
    "   Left Leg:\t", fixed_bursts_left_leg.shape, "\n",
    "   Right Leg:\t", fixed_bursts_right_leg.shape, "\n",
]
print(*training_shape_report)
if leave_one_out is not None:
    loo_shape_report = [
        "------------------------\nLeave One Out: Shape of Muscle Bursts (Fixed Window):\n",
        "   Left Arm:\t", loo_fixed_bursts_left_arm.shape, "\n",
        "   Right Arm:\t", loo_fixed_bursts_right_arm.shape, "\n",
        "   Left Leg:\t", loo_fixed_bursts_left_leg.shape, "\n",
        "   Right Leg:\t", loo_fixed_bursts_right_leg.shape,
    ]
    print(*loo_shape_report)
## Saving Leave One Out as TFRecord
# Only runs when a subject was held out: the loo_fixed_bursts_* arrays do not
# exist otherwise.
if leave_one_out is not None:
    loo_emg_series_complete = [loo_fixed_bursts_left_arm, loo_fixed_bursts_right_arm, loo_fixed_bursts_left_leg,
                               loo_fixed_bursts_right_leg]

    with tf.io.TFRecordWriter('leave_one_out.tfrecord') as tfrecord:
        for emg_muscle in tqdm(range(len(loo_emg_series_complete)),
                               desc="Extracting loo dataset to TFRecords (Fixed Window)", position=0, leave=True):
            # One-hot label over the 4 classes; the position in the list is the class id.
            # Built once per class instead of once per record.
            label_feature = tf.train.Feature(
                float_list=tf.train.FloatList(value=tf.keras.utils.to_categorical(emg_muscle, 4)))
            for sample in range(len(loo_emg_series_complete[emg_muscle])):
                for ch in range(channels):
                    # One record per (burst, channel): the window samples of a single channel
                    sample_in_channel = loo_emg_series_complete[emg_muscle][sample][:, ch]
                    features = {
                        'label': label_feature,
                        'feature': tf.train.Feature(float_list=tf.train.FloatList(value=sample_in_channel))
                    }
                    example = tf.train.Example(features=tf.train.Features(feature=features))
                    tfrecord.write(example.SerializeToString())


# Reference: https://www.rustyrobotics.com/posts/tensorflow/tfdataset-record-count/
def countRecords(ds: tf.data.Dataset):
    """Count the elements of `ds` by iterating over the whole dataset once."""
    if tf.executing_eagerly():
        # TF v2 or v1 in eager mode: the dataset is directly iterable
        return sum(1 for _ in ds)

    # TF v1 in non-eager mode: pull elements through a session until exhausted
    next_batch = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
    count = 0
    with tf.compat.v1.Session() as sess:
        while True:
            try:
                sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                break
            count += 1
    return count


# Passed as num_parallel_calls and prefetch buffer size by loo_get_dataset
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Batch size applied by loo_get_dataset
BATCH_SIZE = 1024
# Samples per burst window; loo_read_tfrecord reshapes each record to (window, 1)
window = 500


def loo_read_tfrecord(serialized_example):
    """Parse one serialized leave-one-out record into ``(feature, label)``.

    Returns
    -------
    tuple
        ``feature`` with shape ``(window, 1)`` and the 4-element float ``label``.
    """
    tfrecord_format = {
        'label': tf.io.FixedLenFeature([4], tf.float32),
        # Length comes from the module-level `window` so the parse spec and the
        # reshape below cannot disagree.
        'feature': tf.io.FixedLenFeature([window], tf.float32),
    }
    example = tf.io.parse_single_example(serialized_example, tfrecord_format)
    f = tf.reshape(example['feature'], [window, 1])
    f.set_shape([window, 1])
    return f, example['label']


def loo_get_dataset(tf_record_name):
    """Build the shuffled, batched dataset of parsed records stored in `tf_record_name`."""
    parsed_records = tf.data.TFRecordDataset(tf_record_name).map(
        loo_read_tfrecord, num_parallel_calls=AUTOTUNE)
    # Shuffle buffer as large as the dataset; counting iterates the file once
    record_count = countRecords(parsed_records)
    return (
        parsed_records
        .shuffle(record_count)
        .prefetch(buffer_size=AUTOTUNE)
        .batch(BATCH_SIZE)
    )


# Load the leave-one-out dataset and print the shape of every batch
loo_dataset = loo_get_dataset('leave_one_out.tfrecord')
for batch_feature, batch_label in loo_dataset:
    print(f'label={batch_label.shape}, feature={batch_feature.shape}')
# TFRecords: Storing Training and Validation Datasets in Tensorflow Records

# Reference: https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/load_data/tfrecord.ipynb#scrollTo=_e3g9ExathXP


## Write TFRecords
def extract_burst_windows_tfrecord(emg_series_complete, onset_list, window, left_shift,
                                   channels, output_path='all_dataset.tfrecord'):  #CAREFUL HERE WITH NUMBER OF TOTAL SUBJECTS. DO NOT REPEAT SAME SUBJECT
    """Write one TFRecord example per detected burst.

    Parameters
    ----------
    emg_series_complete : sequence
        ``emg_series_complete[muscle][subject]`` is a 2-D recording indexed as
        ``[row, column]``; column 0 is skipped, columns ``1..channels`` are stored.
    onset_list : sequence
        ``onset_list[muscle][subject][0]`` holds the onset times in seconds.
    window : int
        Number of rows per stored window.
    left_shift : int
        Number of rows the window starts before the onset.
    channels : int
        Number of channel columns stored as ``feature_ch1 .. feature_ch<channels>``.
    output_path : str, optional
        Destination TFRecord file.

    Notes
    -----
    Bursts whose window does not fit completely inside the recording are
    skipped so that every stored feature has exactly ``window`` values. The
    ``burst`` counter still advances for them, so it keeps matching the
    position in the onset list.
    """
    with tf.io.TFRecordWriter(output_path) as tfrecord:
        for emg_muscle in tqdm(range(len(emg_series_complete)),
                               desc="Extracting dataset to TFRecords (Fixed Window)",
                               position=0, leave=True):
            # One-hot label over 4 classes; the position in the outer list is the class id
            label_feature = tf.train.Feature(
                float_list=tf.train.FloatList(value=tf.keras.utils.to_categorical(emg_muscle, 4)))
            for subject in range(len(emg_series_complete[emg_muscle])):
                recording = emg_series_complete[emg_muscle][subject]
                burst = 0
                for onset in onset_list[emg_muscle][subject][0]:
                    # round, not truncate: seconds * 1000 can land just below the intended row
                    onset_ms = int(round(onset * 1000)) - left_shift
                    burst += 1
                    if onset_ms < 0 or onset_ms + window > len(recording):
                        continue  # incomplete window
                    features = {'label': label_feature}
                    for ch in range(channels):
                        features['feature_ch{}'.format(ch + 1)] = tf.train.Feature(
                            float_list=tf.train.FloatList(
                                value=recording[onset_ms:onset_ms + window, ch + 1]))
                    features['subject'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[subject + 1]))
                    features['burst'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[burst]))
                    # Kept for file-format compatibility: this field has always held the
                    # number of stored channels (the last channel number).
                    features['channel'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[channels]))
                    example = tf.train.Example(features=tf.train.Features(feature=features))
                    tfrecord.write(example.SerializeToString())


# Class order (list position = class id): left arm, right arm, left leg, right leg
emg_series_tot = [
    emg_series_left_arm,
    emg_series_right_arm,
    emg_series_left_leg,
    emg_series_right_leg,
]
activation_burst_tot = [
    activation_burst_left_arm,
    activation_burst_right_arm,
    activation_burst_left_leg,
    activation_burst_right_leg,
]
extract_burst_windows_tfrecord(
    emg_series_complete=emg_series_tot,
    onset_list=activation_burst_tot,
    window=window,
    left_shift=left_shift,
    channels=channels,
)


## Read TFRecords
def map_fn(serialized_example):
    """Parse one serialized multi-channel burst record.

    Args:
        serialized_example: a serialized ``tf.train.Example`` holding the keys
            ``label``, ``feature_ch1`` .. ``feature_ch16``, ``subject`` and ``burst``.

    Returns:
        A flat 19-tuple ``(label, feature_ch1, ..., feature_ch16, subject, burst)``.
        ``label`` has 4 floats, each channel feature has 500 floats, and
        ``subject`` / ``burst`` are scalar int64 values.
    """
    channel_keys = ['feature_ch{}'.format(number) for number in range(1, 17)]

    # Every channel shares the same fixed-length float spec.
    schema = {key: tf.io.FixedLenFeature([500], tf.float32) for key in channel_keys}
    schema.update(
        label=tf.io.FixedLenFeature([4], tf.float32),
        subject=tf.io.FixedLenFeature([], tf.int64),
        burst=tf.io.FixedLenFeature([], tf.int64),
    )

    parsed = tf.io.parse_single_example(serialized_example, schema)

    # Output order: label first, then the channels in numeric order, then the ids.
    ordered_keys = ['label'] + channel_keys + ['subject', 'burst']
    return tuple([parsed[key] for key in ordered_keys])


# Load the combined TFRecord file and decode every record with map_fn.
dataset = tf.data.TFRecordDataset('all_dataset.tfrecord').map(map_fn)

# Sanity check: print label, subject and burst index of every decoded record.
for label, ch1, ch2, ch3, ch4, ch5, ch6, ch7, ch8, ch9, ch10, ch11, ch12, ch13, ch14, ch15, ch16, subject, burst in dataset:
    print(f'label={label}, subject={subject}, burst={burst}')


# plt.plot(ch1)
def separate_dataset_per_subject_train_val(dataset, subj, train_percentage, seed=None):
    """Split the records of one subject into disjoint training and validation datasets.

    Args:
        dataset: dataset whose elements are 19-tuples
            ``(label, ch1, ..., ch16, subject, burst)``.
        subj: subject id to keep; compared against each record's ``subject`` field.
        train_percentage: fraction of the subject's records assigned to training;
            the training size is ``int(n_records * train_percentage)``.
        seed: optional shuffle seed, forwarded to ``Dataset.shuffle`` to make the
            split reproducible across runs.

    Returns:
        ``(train_dataset, val_dataset)`` for the requested subject.
    """
    # Keep only the records belonging to the requested subject.
    dataset_subject = dataset.filter(
        lambda label, ch1, ch2, ch3, ch4, ch5, ch6, ch7, ch8, ch9, ch10, ch11, ch12, ch13, ch14, ch15, ch16, subject,
               burst: subject == subj)

    # NOTE(review): countRecords is defined elsewhere in the file; it is used here
    # as the number of elements in the dataset - confirm.
    dataset_subject_samples = countRecords(dataset_subject)
    train_samples = int(dataset_subject_samples * train_percentage)

    # By default shuffle() draws a new order every time the dataset is iterated.
    # take() and skip() below are iterated independently, so without a fixed
    # order the same record could land in both the training and validation
    # splits. reshuffle_each_iteration=False keeps one order for both branches.
    dataset_subject_shuffled = dataset_subject.shuffle(
        dataset_subject_samples, seed=seed, reshuffle_each_iteration=False)

    dataset_subject_train = dataset_subject_shuffled.take(train_samples)
    # Everything after the training prefix is validation.
    dataset_subject_val = dataset_subject_shuffled.skip(train_samples)
    return dataset_subject_train, dataset_subject_val


train_percentage = 0.8

# One (train, val) pair per subject id 1..8, built in subject order.
subject_splits = [
    separate_dataset_per_subject_train_val(dataset, subject_id, train_percentage)
    for subject_id in range(1, 9)
]
all_subject_datasets_train = [train_split for train_split, _ in subject_splits]
all_subject_datasets_val = [val_split for _, val_split in subject_splits]

# Keep the individual per-subject names bound as well.
(dataset_subject1_train, dataset_subject2_train, dataset_subject3_train, dataset_subject4_train,
 dataset_subject5_train, dataset_subject6_train, dataset_subject7_train,
 dataset_subject8_train) = all_subject_datasets_train
(dataset_subject1_val, dataset_subject2_val, dataset_subject3_val, dataset_subject4_val,
 dataset_subject5_val, dataset_subject6_val, dataset_subject7_val,
 dataset_subject8_val) = all_subject_datasets_val

# Peek at 15 training records of subject 1.
for label, ch1, ch2, ch3, ch4, ch5, ch6, ch7, ch8, ch9, ch10, ch11, ch12, ch13, ch14, ch15, ch16, subject, burst in dataset_subject1_train.take(
        15):
    print(f'label={label}, subject={subject}, burst={burst}')


def augment_datasets(collection_datasets, tf_record_name):
    """Expand multi-channel burst records into single-channel records on disk.

    Every element ``(label, ch1, ..., ch16, subject, burst)`` of every dataset
    in ``collection_datasets`` is written as one ``tf.train.Example`` per
    channel, with the keys ``label``, ``feature``, ``subject``, ``burst`` and
    ``channel`` (1-based channel number).

    Args:
        collection_datasets: iterable of datasets yielding tuples laid out as
            ``(label, <channel windows...>, subject, burst)``.
        tf_record_name: path of the TFRecord file to write.

    Returns:
        None. The records are written to ``tf_record_name``.
    """
    with tf.io.TFRecordWriter(tf_record_name) as tfrecord:
        for d in collection_datasets:
            # Starred unpacking collects the channel windows in order, which
            # replaces looking up the ch1..ch16 locals by name through eval().
            for label, *channel_windows, subject, burst in d:
                # These values are identical for all channels of the record.
                label_values = np.asarray(label)
                subject_id = int(subject)
                burst_id = int(burst)
                for channel_number, channel_window in enumerate(channel_windows, start=1):
                    features = {
                        'label': tf.train.Feature(float_list=tf.train.FloatList(value=label_values)),
                        'feature': tf.train.Feature(
                            float_list=tf.train.FloatList(value=np.asarray(channel_window))),
                        'subject': tf.train.Feature(int64_list=tf.train.Int64List(value=[subject_id])),
                        'burst': tf.train.Feature(int64_list=tf.train.Int64List(value=[burst_id])),
                        'channel': tf.train.Feature(int64_list=tf.train.Int64List(value=[channel_number]))
                    }
                    example = tf.train.Example(features=tf.train.Features(feature=features))
                    tfrecord.write(example.SerializeToString())


def map_fn_final(serialized_example):
    """Parse one serialized single-channel burst record.

    Args:
        serialized_example: a serialized ``tf.train.Example`` with the keys
            ``label``, ``feature``, ``subject``, ``burst`` and ``channel``.

    Returns:
        The tuple ``(label, feature, subject, burst, channel)``: 4 floats,
        500 floats, and three scalar int64 values.
    """
    scalar_int64 = tf.io.FixedLenFeature([], tf.int64)
    schema = {
        'label': tf.io.FixedLenFeature([4], tf.float32),
        'feature': tf.io.FixedLenFeature([500], tf.float32),
        'subject': scalar_int64,
        'burst': scalar_int64,
        'channel': scalar_int64,
    }
    parsed = tf.io.parse_single_example(serialized_example, schema)
    field_order = ('label', 'feature', 'subject', 'burst', 'channel')
    return tuple([parsed[field] for field in field_order])


# Write the per-channel records of the training and validation splits to
# their own TFRecord files (training first, then validation).
for split_datasets, split_path in (
        (all_subject_datasets_train, 'all_mixed_train.tfrecord'),
        (all_subject_datasets_val, 'all_mixed_val.tfrecord'),
):
    augment_datasets(split_datasets, split_path)


def mix_and_shuffle_datasets(tf_record_name):
    """Load a single-channel TFRecord file and return it as a shuffled dataset.

    Args:
        tf_record_name: path of a TFRecord file readable by ``map_fn_final``.

    Returns:
        A dataset of ``(label, feature, subject, burst, channel)`` tuples,
        shuffled with a buffer sized by ``countRecords`` of the parsed dataset.
    """
    parsed_records = tf.data.TFRecordDataset(tf_record_name).map(map_fn_final)
    # NOTE(review): countRecords is defined elsewhere; presumably it returns
    # the element count, which makes the buffer cover the whole dataset - confirm.
    shuffle_buffer = countRecords(parsed_records)
    return parsed_records.shuffle(shuffle_buffer)


# Shuffled single-channel datasets for the training and validation splits.
dataset_final_train = mix_and_shuffle_datasets('all_mixed_train.tfrecord')
dataset_final_val = mix_and_shuffle_datasets('all_mixed_val.tfrecord')

# Peek at 20 shuffled training records (feature is reported by shape only).
for label, feature, subject, burst, channel in dataset_final_train.take(20):
    print(f'label={label}, feature={feature.shape}, subject={subject}, burst={burst}, channel={channel}')


## Load TFRecords
def read_tfrecord(serialized_example, export_subject=False):
    """Parse one single-channel record into a model-ready ``(feature, label)`` pair.

    The feature length comes from the module-level ``window`` for both the
    parsing spec and the reshape, so the two cannot disagree.

    Args:
        serialized_example: a serialized ``tf.train.Example`` with the keys
            ``label``, ``feature``, ``subject``, ``burst`` and ``channel``.
        export_subject: when truthy, the record's ``subject`` is returned as a
            third element.

    Returns:
        ``(feature, label)`` or ``(feature, label, subject)``, where ``feature``
        has shape ``[window, 1]`` and ``label`` has shape ``[4]``.
    """
    tfrecord_format = {
        'label': tf.io.FixedLenFeature([4], tf.float32),
        'feature': tf.io.FixedLenFeature([window], tf.float32),
        'subject': tf.io.FixedLenFeature([], tf.int64),
        'burst': tf.io.FixedLenFeature([], tf.int64),
        'channel': tf.io.FixedLenFeature([], tf.int64)
    }
    example = tf.io.parse_single_example(serialized_example, tfrecord_format)
    # Add a trailing axis of size 1; reshaping to a constant shape already
    # fixes the static shape, so no separate set_shape() call is needed.
    f = tf.reshape(example['feature'], [window, 1])
    if export_subject:
        return f, example['label'], example['subject']
    return f, example['label']


# tf.data autotuning sentinel, passed to map(num_parallel_calls=...) and
# prefetch(buffer_size=...) in get_dataset.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Number of examples per batch produced by get_dataset.
BATCH_SIZE = 1024
# Number of samples per burst window; read_tfrecord reshapes each feature to [window, 1].
# NOTE(review): `window` is already used by code above this assignment, so it
# must also be defined earlier in the file - confirm both definitions agree.
window = 500


def get_dataset(tf_record_name):
    """Build a shuffled, batched and prefetched dataset from a TFRecord file.

    Args:
        tf_record_name: path of a single-channel TFRecord file readable by
            ``read_tfrecord``.

    Returns:
        A dataset of ``(feature, label)`` batches holding up to ``BATCH_SIZE``
        examples each.
    """
    dataset = tf.data.TFRecordDataset(tf_record_name)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    # NOTE(review): countRecords is defined elsewhere; it is used here as the
    # element count so the shuffle buffer spans the whole dataset - confirm.
    dataset_samples = countRecords(dataset)
    dataset = dataset.shuffle(dataset_samples)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch last so that whole batches are prepared while the consumer is
    # busy with the previous one; the yielded batches themselves are unchanged.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset


train_dataset = get_dataset('all_mixed_train.tfrecord')
valid_dataset = get_dataset('all_mixed_val.tfrecord')

# Inspect the batch shapes of the training dataset built just above.
# NOTE(review): this loop previously iterated `loo_dataset`, which is not
# defined anywhere in the visible code before this point and would raise
# NameError on a fresh kernel - confirm that train_dataset was the intent.
for feature, label in train_dataset:
    print('label={}, feature={}'.format(label.shape, feature.shape))


## Plotting Muscle Bursts
# quick plot to see individual contraction bursts
def plot_independent_bursts(label, burst_list, number_plots, color, fixed, stride=16):
    """Plot individual contraction bursts on a grid of subplots, 4 per row.

    Burst ``k`` (0-based) is read from ``burst_list[k * stride]`` and its
    first column is plotted.

    Args:
        label: muscle name shown in the figure title.
        burst_list: sequence of 2-D arrays (indexed as ``item[:, 0]``).
        number_plots: maximum number of bursts to plot; reduced when
            ``burst_list`` does not hold that many entries at multiples of
            ``stride``.
        color: matplotlib line colour.
        fixed: text shown in the title as ``"(<fixed> length)"``.
        stride: distance between consecutive plotted entries of ``burst_list``.
            NOTE(review): the default 16 presumably matches 16 consecutive
            per-channel entries per burst - confirm against how the caller's
            lists are built.

    Returns:
        None. The figure is created through ``plt.subplots``.
    """
    # Only indices 0, stride, 2*stride, ... are read, so the number of bursts
    # that can be plotted is ceil(len / stride), not len. Clamping against len
    # alone allowed burst_list[n_burst * stride] to run past the end.
    plottable = (len(burst_list) + stride - 1) // stride
    number_plots = min(number_plots, plottable)
    if number_plots <= 0:
        # Nothing to draw; a zero-row subplot grid cannot be created.
        return

    n_cols = 4
    n_rows = (number_plots + n_cols - 1) // n_cols  # integer ceil division

    # squeeze=False always yields a 2-D axes array, so a single row needs no
    # special-case indexing.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(18, n_rows * 2), dpi=150, squeeze=False)
    fig.suptitle('Contraction Bursts: {} ({} length)'.format(label, fixed))

    for n_burst in range(number_plots):
        ax = axs[n_burst // n_cols, n_burst % n_cols]
        ax.plot(burst_list[n_burst * stride][:, 0], color=color)
        ax.set_title('Burst {}'.format(n_burst + 1))

    for ax in axs.flat:
        ax.set(xlabel='time (ms)', ylabel='EMG')
        # Hide x labels and tick labels for top plots and y ticks for right plots.
        ax.label_outer()


# Plot up to 10 fixed-length bursts per muscle, each muscle in its own colour.
for muscle_name, muscle_bursts, line_color in (
        ("Left Biceps", fixed_bursts_left_arm, "cornflowerblue"),
        ("Right Biceps", fixed_bursts_right_arm, "yellowgreen"),
        ("Left Tibialis Ant.", fixed_bursts_left_leg, "orange"),
        ("Right Tibialis Ant.", fixed_bursts_right_leg, "purple"),
):
    plot_independent_bursts(muscle_name, muscle_bursts, 10, line_color, "Fixed")