In [1]:
import numpy as np
import pandas as pd
from scipy import signal
from scipy.signal import butter, filtfilt, welch


In [2]:
all_data = np.load('processed_data/all_data.npy')
all_labels = np.load('processed_data/all_labels.npy')
all_data = all_data[:,1:,:]

In [3]:
all_data.shape

(240, 4, 1400)

# Split Data

In [4]:
from sklearn.model_selection import train_test_split

In [5]:
X_train, X_test, y_train, y_test = train_test_split(all_data, all_labels, test_size=0.2, random_state=42)

In [6]:
sr = 1000

# Make Features

In [7]:
def bandpassFilter(data, sr, lowcut, highcut):
    nyq = 0.5 * sr
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(2, [low, high], btype='band')
    return filtfilt(b, a, data)

In [8]:
def WL(data):
    wl = np.sum(np.abs(np.diff(data)))
    return wl / len(data)

In [9]:
def MAV(data):
    return np.sum(np.abs(data))/len(data)

In [10]:
def SSC(data,threshold):
    res = 0
    for i in range(1, len(data)-1):
        curr = (data[i]-data[i-1]) * (data[i+1]-data[i])
        if curr >= threshold:
            res += 1
    return res

In [11]:
def ZC(data):
    """
    Counts how many times the signal crosses zero. 
    """
    res = 0
    for i in range(1, len(data)):
        curr = data[i] * data[i-1]
        if curr < 0:
            res += 1
    return res

In [12]:
def VAR(data):
    """
    Measures the spread of signal values around the mean. It reflects the level of muscle activation.
    """
    return np.var(data)

In [13]:
def RMS(data):
    """
    Measures the amplitude of the signal.
    """
    return np.sqrt(np.mean(data**2))

In [14]:
def MF(data, sr):
    f, Pxx = welch(data, sr, nperseg=1024)
    cumulative_power = np.cumsum(Pxx)
    total_power = np.sum(Pxx)
    median_freq = np.interp(total_power/2, cumulative_power, f)
    return median_freq

In [15]:
def feature_extraction(trials):
    num_trials, num_channels, _ = trials.shape  # Get the dimensions of trials
    features = np.empty((num_trials, num_channels * 7))  # Initialize an empty array to store the features

    for j in range(num_channels):
        tmp = []
        for i in range(num_trials):
            slice = trials[i, j, :]  # Access the data for the current trial and channel
            # slice = bandpassFilter(slice, sr, 20, 450)
            wl = WL(slice)
            mav = MAV(slice)
            ssc = SSC(slice, 0.001)
            zc = ZC(slice)
            var = VAR(slice)
            rms = RMS(slice)
            mf = MF(slice, sr)
            features[i, j*7:(j+1)*7] = [wl, mav, ssc, zc, var, rms, mf]  # Store the features in the correct location
    return np.array(features)

In [16]:
train_features = feature_extraction(X_train)

In [17]:
X_train.shape

(192, 4, 1400)

In [18]:
train_features.shape

(192, 28)

In [19]:
test_features = feature_extraction(X_test)

In [20]:
columns = [
    'ch1_wl', 'ch1_mav', 'ch1_ssc', 'ch1_zc', 'ch1_var', 'ch1_rms', 'ch1_mf',
    'ch2_wl', 'ch2_mav', 'ch2_ssc', 'ch2_zc', 'ch2_var', 'ch2_rms', 'ch2_mf',
    'ch3_wl', 'ch3_mav', 'ch3_ssc', 'ch3_zc', 'ch3_var', 'ch3_rms', 'ch3_mf',
    'ch4_wl', 'ch4_mav', 'ch4_ssc', 'ch4_zc', 'ch4_var', 'ch4_rms', 'ch4_mf',
]


In [21]:
df_train = pd.DataFrame(train_features, columns=columns)

In [22]:
df_train

Unnamed: 0,ch1_wl,ch1_mav,ch1_ssc,ch1_zc,ch1_var,ch1_rms,ch1_mf,ch2_wl,ch2_mav,ch2_ssc,...,ch3_var,ch3_rms,ch3_mf,ch4_wl,ch4_mav,ch4_ssc,ch4_zc,ch4_var,ch4_rms,ch4_mf
0,740.638187,1201.952173,976.0,278.0,5.724697e+06,2392.634201,71.204258,593.939302,968.011267,990.0,...,1.884613e+06,1372.812186,68.215072,1068.742965,1949.836649,923.0,334.0,2.237983e+07,4730.732130,59.569130
1,5671.605541,12181.245023,995.0,302.0,5.035037e+09,70957.991115,68.428012,5235.952267,11917.906344,1021.0,...,6.609350e+08,25708.656200,77.708716,3637.289592,7724.444659,958.0,301.0,1.603305e+09,40041.295013,67.814069
2,9031.120862,18861.028470,1011.0,256.0,3.880166e+09,62290.978564,63.879842,7658.328576,16311.040929,1014.0,...,1.435144e+08,11979.748666,88.955221,6755.266037,13839.458082,1002.0,249.0,1.563388e+09,39539.699067,63.388043
3,28209.085414,59919.058845,990.0,241.0,7.071888e+10,265930.214645,63.810770,24569.511530,49755.048958,927.0,...,4.803823e+09,69309.618598,58.747995,12266.519377,22125.748594,915.0,329.0,4.492292e+09,67024.563582,65.632988
4,11910.399561,25642.306364,958.0,293.0,6.795379e+09,82434.090302,63.635049,10139.537940,22531.455850,954.0,...,7.555628e+08,27487.503413,84.912082,5082.531298,10611.109352,990.0,267.0,1.346180e+09,36690.331948,63.290557
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
187,10134.951745,21144.473304,977.0,275.0,6.966227e+09,83463.925677,63.545906,9566.290131,20402.380781,959.0,...,3.830812e+07,6189.355404,99.669223,2516.340281,4311.837758,955.0,301.0,1.099478e+08,10485.600408,87.260755
188,1566.485682,2444.302783,1002.0,274.0,1.588353e+07,3985.415518,100.056176,1341.003895,2167.125889,1007.0,...,8.029841e+06,2833.697566,99.081033,1082.139986,1750.215206,981.0,286.0,9.525803e+06,3086.393421,81.059171
189,1868.660016,3285.056366,944.0,302.0,1.365721e+08,11686.407793,86.912601,1799.625139,3226.150340,938.0,...,6.489029e+06,2547.358492,95.510320,1838.142767,3793.496204,940.0,291.0,2.106362e+08,14513.311974,80.379906
190,2916.440546,4937.063244,973.0,279.0,1.862599e+08,13647.708412,77.465730,1964.599133,3155.482541,945.0,...,4.001836e+07,6326.006828,77.907225,1562.841559,2428.066583,954.0,291.0,3.523956e+07,5936.292228,81.840659


In [23]:
y_train.shape

(192,)

In [24]:
df_train['label'] = [int(i-1) for i in y_train]

In [25]:
df_test = pd.DataFrame(test_features, columns=columns)
df_test['label'] = [int(i-1) for i in y_test]

In [26]:
df_train.to_csv('processed_data/feature_train.csv', index=False)
df_test.to_csv('processed_data/feature_test.csv', index=False)

In [27]:
df_test

Unnamed: 0,ch1_wl,ch1_mav,ch1_ssc,ch1_zc,ch1_var,ch1_rms,ch1_mf,ch2_wl,ch2_mav,ch2_ssc,...,ch3_rms,ch3_mf,ch4_wl,ch4_mav,ch4_ssc,ch4_zc,ch4_var,ch4_rms,ch4_mf,label
0,10770.351735,19637.774288,1018.0,293.0,10030480000.0,100152.276549,88.463741,10302.135195,17994.957686,1026.0,...,26535.405618,68.844751,4003.314808,7627.015501,1016.0,286.0,932993700.0,30544.946345,72.311498,0
1,1801.766396,2878.945665,1024.0,272.0,17351510.0,4165.518955,88.663917,1370.25892,2213.715793,1025.0,...,2950.609467,88.85136,1257.332613,2028.096124,1019.0,275.0,8927063.0,2987.824631,89.193663,2
2,6690.791905,14707.538869,1008.0,280.0,7444250000.0,86280.069046,67.362492,6425.905561,14239.986812,990.0,...,2505.406892,85.503126,2151.900187,3810.956446,1041.0,268.0,84226990.0,9177.526448,74.304328,2
3,1106.550094,1850.630112,1002.0,278.0,10544800.0,3247.276307,83.943416,867.713132,1461.749498,1010.0,...,2235.540675,87.144596,1723.475907,3373.041824,1022.0,258.0,131114500.0,11450.52283,78.86748,0
4,1673.494733,2786.142018,921.0,313.0,33910090.0,5823.23696,94.440817,1265.729638,2050.197477,884.0,...,3097.217415,98.364642,1680.49449,2897.889945,900.0,330.0,44515300.0,6671.988258,86.197539,2
5,1061.194992,1705.134804,954.0,296.0,8228504.0,2868.536934,75.243157,702.327139,1045.905996,952.0,...,1683.268782,78.020172,752.98391,1218.005467,963.0,311.0,4415086.0,2101.21125,76.419812,2
6,3853.626772,5641.898338,975.0,289.0,166539300.0,12905.010173,86.81551,4674.574298,7082.472295,1022.0,...,10465.402623,93.292779,3642.144729,5612.263437,1002.0,300.0,206020600.0,14353.418983,92.852298,1
7,12868.817775,31433.524907,998.0,258.0,19754770000.0,140551.66776,59.20401,2324.031921,4437.951871,1017.0,...,138514.345593,58.877653,12970.624886,32872.416893,971.0,293.0,15003450000.0,122488.585557,53.884763,0
8,1715.162749,2752.746783,995.0,288.0,17889940.0,4229.665952,86.965741,1511.60163,2538.442978,1025.0,...,2991.634523,91.908588,1302.736293,2227.992105,1022.0,271.0,11446910.0,3383.333135,82.559323,2
9,1462.543319,2327.552358,961.0,298.0,47410330.0,6885.51612,86.449466,1440.344636,2264.188178,938.0,...,2337.284192,89.156829,1511.266287,2402.751115,912.0,326.0,69694390.0,8348.317117,97.430668,1
