In [1]:
import numpy as np
import pandas as pd
from scipy import signal
from scipy.signal import butter, filtfilt, welch


In [2]:
all_data = np.load('processed_data/onset_data.npy')
all_labels = np.load('processed_data/onset_labels.npy')


# Split Data

In [3]:
from sklearn.model_selection import train_test_split

In [4]:
X_train, X_test, y_train, y_test = train_test_split(all_data, all_labels, test_size=0.2, random_state=42)

In [5]:
sr = 1000

# Make Features

In [7]:
def WL(data):
    wl = np.sum(np.abs(np.diff(data)))
    return wl / len(data)

In [8]:
def MAV(data):
    return np.sum(np.abs(data))/len(data)

In [9]:
def SSC(data,threshold):
    res = 0
    for i in range(1, len(data)-1):
        curr = (data[i]-data[i-1]) * (data[i+1]-data[i])
        if curr >= threshold:
            res += 1
    return res

In [10]:
def ZC(data):
    """
    Counts how many times the signal crosses zero. 
    """
    res = 0
    for i in range(1, len(data)):
        curr = data[i] * data[i-1]
        if curr < 0:
            res += 1
    return res

In [11]:
def VAR(data):
    """
    Measures the spread of signal values around the mean. It reflects the level of muscle activation.
    """
    return np.var(data)

In [12]:
def RMS(data):
    """
    Measures the amplitude of the signal.
    """
    return np.sqrt(np.mean(data**2))

In [13]:
def MF(data, sr):
    f, Pxx = welch(data, sr, nperseg=1024)
    cumulative_power = np.cumsum(Pxx)
    total_power = np.sum(Pxx)
    median_freq = np.interp(total_power/2, cumulative_power, f)
    return median_freq

In [14]:
def PF(data, fs):
    freqs, psd = welch(data, fs)
    peak_freq = freqs[np.argmax(psd)]

    return peak_freq

In [15]:
def calculate_hjorth_parameters(data):
    # Activity is the signal variance
    activity = np.var(data)

    # Mobility is the square root of the variance of the first derivative of the signal
    # divided by the activity
    mobility = np.sqrt(np.var(np.diff(data)) / activity)

    # Complexity is the mobility of the first derivative of the signal divided by the mobility
    complexity = np.sqrt(np.var(np.diff(data, n=2)) / np.var(np.diff(data)))

    return activity, mobility, complexity


In [16]:
def feature_extraction(trials):
    num_trials, num_channels, _ = trials.shape  # Get the dimensions of trials
    features = np.empty((num_trials, num_channels * 11))  # Initialize an empty array to store the features

    for j in range(num_channels):
        tmp = []
        for i in range(num_trials):
            slice = trials[i, j, :]  # Access the data for the current trial and channel
            # slice = bandpassFilter(slice, sr, 20, 450)
            wl = WL(slice)
            mav = MAV(slice)
            ssc = SSC(slice, 0.001)
            zc = ZC(slice)
            var = VAR(slice)
            rms = RMS(slice)
            mf = MF(slice, sr)
            pf = PF(slice, sr)
            activity, mobility, complexity = calculate_hjorth_parameters(slice)
            features[i, j*11:(j+1)*11] = [wl, mav, ssc, zc, var, rms, mf, pf, activity, mobility, complexity]  # Store the features in the correct location
    return np.array(features)

In [17]:
train_features = feature_extraction(X_train)



In [18]:
X_train.shape

(1531, 4, 600)

In [19]:
train_features.shape

(1531, 44)

In [20]:
test_features = feature_extraction(X_test)

In [28]:
columns = [
    'ch1_wl', 'ch1_mav', 'ch1_ssc', 'ch1_zc', 'ch1_var', 'ch1_rms', 'ch1_mf','ch1_pf','ch1_activity','ch1_mobility','ch1_complexity',
    'ch2_wl', 'ch2_mav', 'ch2_ssc', 'ch2_zc', 'ch2_var', 'ch2_rms', 'ch2_mf', 'ch2_pf','ch2_activity','ch2_mobility','ch2_complexity',
    'ch3_wl', 'ch3_mav', 'ch3_ssc', 'ch3_zc', 'ch3_var', 'ch3_rms', 'ch3_mf', 'ch3_pf','ch3_activity','ch3_mobility','ch3_complexity',
    'ch4_wl', 'ch4_mav', 'ch4_ssc', 'ch4_zc', 'ch4_var', 'ch4_rms', 'ch4_mf', 'ch4_pf','ch4_activity','ch4_mobility','ch4_complexity'
]


In [29]:
ch4_wl
ch3_wl
ch4_mobility
ch1_wl
ch2_wl
ch4_mf
ch3_mobility
ch2_mobility
ch4_mav
ch2_mf
ch2_complexity
ch4_zc
ch4_rms
ch3_mav
ch3_zc
ch1_zc
ch4_activity
ch2_mav
ch3_complexity
ch2_zc


'ch4_wl'

In [22]:
df_train = pd.DataFrame(train_features, columns=columns)

In [23]:
df_train

Unnamed: 0,ch1_wl,ch1_mav,ch1_ssc,ch1_zc,ch1_var,ch1_rms,ch1_mf,ch1_pf,ch1_activity,ch1_mobility,...,ch4_mav,ch4_ssc,ch4_zc,ch4_var,ch4_rms,ch4_mf,ch4_pf,ch4_activity,ch4_mobility,ch4_complexity
0,1739.706073,1.187184e+04,464.0,19.0,1.992348e+08,1.414716e+04,4.962568,3.90625,1.992348e+08,0.213198,...,1.349651e+04,429.0,21.0,4.491583e+08,2.120997e+04,5.467278,3.90625,4.491583e+08,0.090306,0.827227
1,739.059805,2.271263e+03,420.0,73.0,1.149516e+07,3.390451e+03,9.661960,7.81250,1.149516e+07,0.339176,...,2.173862e+04,513.0,7.0,1.123632e+09,3.352298e+04,6.632033,7.81250,1.123632e+09,0.057228,0.303900
2,6175.906947,1.003054e+05,516.0,9.0,2.218704e+10,1.489533e+05,6.198776,7.81250,2.218704e+10,0.075741,...,5.443400e+04,462.0,9.0,5.849121e+09,7.647956e+04,5.874940,7.81250,5.849121e+09,0.100495,0.714125
3,2688.654798,6.472630e+03,435.0,75.0,1.056357e+08,1.027802e+04,40.901303,11.71875,1.056357e+08,0.406872,...,3.947034e+03,457.0,65.0,3.818892e+07,6.179783e+03,44.922286,11.71875,3.818892e+07,0.355314,0.704312
4,1185.008108,6.885047e+03,388.0,32.0,1.118447e+08,1.057587e+04,6.912910,7.81250,1.118447e+08,0.157153,...,1.299138e+04,440.0,21.0,4.421169e+08,2.102669e+04,10.231913,11.71875,4.421169e+08,0.104158,0.518852
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1526,1796.277009,5.057801e+03,434.0,67.0,1.118495e+08,1.057606e+04,42.510561,42.96875,1.118495e+08,0.364288,...,4.006269e+04,509.0,7.0,4.096764e+09,6.400671e+04,7.012531,7.81250,4.096764e+09,0.067903,0.387551
1527,1031.264543,1.817151e+03,407.0,102.0,9.738915e+06,3.120724e+03,60.843793,50.78125,9.738915e+06,0.571664,...,1.028683e+03,404.0,107.0,2.827710e+06,1.681648e+03,73.213562,50.78125,2.827710e+06,0.635701,0.936359
1528,1291.878523,3.418196e+03,404.0,84.0,2.741126e+07,5.235577e+03,46.635796,7.81250,2.741126e+07,0.462522,...,2.511381e+04,476.0,7.0,1.549183e+09,3.936282e+04,6.937635,7.81250,1.549183e+09,0.070541,0.523892
1529,104110.168431,1.061181e+06,521.0,20.0,3.044833e+12,1.744945e+06,11.645963,11.71875,3.044833e+12,0.130253,...,1.424583e+06,533.0,12.0,4.154624e+12,2.038362e+06,7.242158,7.81250,4.154624e+12,0.092735,0.444171


In [24]:
df_train['label'] = y_train

In [25]:
df_test = pd.DataFrame(test_features, columns=columns)
df_test['label'] = y_test

In [26]:
df_train.to_csv('processed_data/onset_feature_train.csv', index=False)
df_test.to_csv('processed_data/onset_feature_test.csv', index=False)

In [27]:
df_test

Unnamed: 0,ch1_wl,ch1_mav,ch1_ssc,ch1_zc,ch1_var,ch1_rms,ch1_mf,ch1_pf,ch1_activity,ch1_mobility,...,ch4_ssc,ch4_zc,ch4_var,ch4_rms,ch4_mf,ch4_pf,ch4_activity,ch4_mobility,ch4_complexity,label
0,2386.783601,5600.546818,433.0,81.0,6.872393e+07,8290.507567,47.078788,46.87500,6.872393e+07,0.442210,...,427.0,84.0,2.084214e+07,4565.417833,61.764499,70.31250,2.084214e+07,0.460152,0.834889,0.0
1,832.885496,8605.653124,435.0,24.0,2.975860e+08,17250.958310,11.326012,11.71875,2.975860e+08,0.103601,...,520.0,12.0,1.194032e+10,109271.781182,6.565698,3.90625,1.194032e+10,0.072656,0.359401,0.0
2,950.555069,1955.608594,409.0,103.0,9.388528e+06,3064.210964,51.074170,93.75000,9.388528e+06,0.498633,...,413.0,94.0,2.663890e+06,1632.231388,78.364150,70.31250,2.663890e+06,0.559784,0.835620,1.0
3,2716.108502,9481.407371,394.0,63.0,2.146453e+08,14650.829506,7.730966,7.81250,2.146453e+08,0.256069,...,404.0,119.0,2.036872e+07,4513.613028,72.443085,31.25000,2.036872e+07,0.682159,1.089617,1.0
4,1167.265823,1943.054171,404.0,110.0,1.046680e+07,3235.839865,91.184709,109.37500,1.046680e+07,0.612405,...,397.0,120.0,5.764866e+06,2401.043987,80.820067,82.03125,5.764866e+06,0.582615,0.818306,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
378,1403.199469,3673.918562,384.0,71.0,3.297804e+07,5742.939019,9.965431,7.81250,3.297804e+07,0.377445,...,376.0,119.0,1.158616e+07,3403.857189,90.326648,39.06250,1.158616e+07,0.743700,1.112615,2.0
379,297.360579,806.621353,389.0,85.0,1.716734e+06,1310.718266,12.083048,7.81250,1.716734e+06,0.414208,...,354.0,37.0,1.708291e+06,1310.221126,4.502221,3.90625,1.708291e+06,0.166315,1.147723,0.0
380,3359.466468,5111.155341,369.0,136.0,4.887819e+07,6991.511706,88.473510,93.75000,4.887819e+07,0.646839,...,392.0,99.0,4.253514e+07,6522.103331,70.670203,7.81250,4.253514e+07,0.497182,0.938617,1.0
381,1845.952530,5463.807064,433.0,59.0,6.208474e+07,7881.997575,10.162589,7.81250,6.208474e+07,0.389076,...,423.0,18.0,2.259727e+08,15032.618666,7.850806,7.81250,2.259727e+08,0.135216,0.836547,0.0
