In [1]:
import numpy as np
import pandas as pd
from scipy import signal
from scipy.signal import butter, filtfilt, welch
from sklearn.preprocessing import StandardScaler


In [2]:
all_data = np.load('processed_data/tune_onset_data.npy')
all_labels = np.load('processed_data/tune_labels.npy')


# Split Data

In [3]:
from sklearn.model_selection import train_test_split

In [4]:
X_train, X_test, y_train, y_test = train_test_split(all_data, all_labels, test_size=0.2, random_state=42)

In [5]:
sr = 1000

# Make Features
Waveform Length (WL): Calculates the cumulative length of the waveform, representing the complexity of the signal.

Mean Absolute Value (MAV): Computes the average absolute value, indicating the signal’s energy.

Slope Sign Changes (SSC): Counts the number of times the signal's slope changes, which helps in identifying muscle contraction patterns.

Zero Crossing (ZC): Measures how many times the signal crosses zero, useful for identifying signal frequency content.

Variance (VAR): Determines the spread of signal values around the mean, reflecting the level of muscle activation.

Root Mean Square (RMS): Calculates the signal's amplitude, which is indicative of muscle strength.

Median Frequency (MF) and Peak Frequency (PF): These functions use the Welch method to calculate the median and peak frequencies of the signal, giving insights into the dominant frequencies.

Hjorth Parameters: This function computes three statistical parameters – activity, mobility, and complexity – which provide information about the signal's characteristics such as variability and waveform shape.

Band Power (BP): Calculates the power of the signal within a specific frequency band, important for understanding the signal’s energy distribution.


In [7]:
def WL(data):
    wl = np.sum(np.abs(np.diff(data)))
    return wl / len(data)

In [8]:
def MAV(data):
    return np.sum(np.abs(data))/len(data)

In [9]:
def SSC(data,threshold):
    res = 0
    for i in range(1, len(data)-1):
        curr = (data[i]-data[i-1]) * (data[i+1]-data[i])
        if curr >= threshold:
            res += 1
    return res

In [10]:
def ZC(data):
    """
    Counts how many times the signal crosses zero. 
    """
    res = 0
    for i in range(1, len(data)):
        curr = data[i] * data[i-1]
        if curr < 0:
            res += 1
    return res

In [11]:
def VAR(data):
    """
    Measures the spread of signal values around the mean. It reflects the level of muscle activation.
    """
    return np.var(data)

In [12]:
def RMS(data):
    """
    Measures the amplitude of the signal.
    """
    return np.sqrt(np.mean(data**2))

In [13]:
def MF(data, sr):
    f, Pxx = welch(data, sr, nperseg=1024)
    cumulative_power = np.cumsum(Pxx)
    total_power = np.sum(Pxx)
    median_freq = np.interp(total_power/2, cumulative_power, f)
    return median_freq

In [14]:
def PF(data, fs):
    freqs, psd = welch(data, fs)
    peak_freq = freqs[np.argmax(psd)]

    return peak_freq

In [15]:
def calculate_hjorth_parameters(data):
    # Activity is the signal variance
    activity = np.var(data)

    # Mobility is the square root of the variance of the first derivative of the signal
    # divided by the activity
    mobility = np.sqrt(np.var(np.diff(data)) / activity)

    # Complexity is the mobility of the first derivative of the signal divided by the mobility
    complexity = np.sqrt(np.var(np.diff(data, n=2)) / np.var(np.diff(data)))

    return activity, mobility, complexity


In [16]:
def BP(data, fs, band=(20,450)):
    """Calculates the band power """
    freqs, psd = welch(data, fs, window='hann', nperseg=1024, scaling='density')
    freq_mask = (freqs >= band[0]) & (freqs <= band[1])
    bp = np.trapz(psd[freq_mask], freqs[freq_mask])
    return bp

In [17]:
def feature_extraction(trials):
    num_trials, num_channels, _ = trials.shape  # Get the dimensions of trials
    features = np.empty((num_trials, num_channels * 12))  # Initialize an empty array to store the features

    for j in range(num_channels):
        tmp = []
        for i in range(num_trials):
            slice = trials[i, j, :]  # Access the data for the current trial and channel
            # slice = bandpassFilter(slice, sr, 20, 450)
            wl = WL(slice)
            mav = MAV(slice)
            ssc = SSC(slice, 0.001)
            zc = ZC(slice)
            var = VAR(slice)
            rms = RMS(slice)
            mf = MF(slice, sr)
            pf = PF(slice, sr)
            bp = BP(slice, sr)
            activity, mobility, complexity = calculate_hjorth_parameters(slice)
            features[i, j*12:(j+1)*12] = [wl, mav, ssc, zc, var, rms, mf, pf, activity, mobility, complexity, bp]  # Store the features in the correct location
    return np.array(features)

In [18]:
train_features = feature_extraction(X_train)



In [19]:
test_features = feature_extraction(X_test)

In [20]:
columns = [
    'ch1_wl', 'ch1_mav', 'ch1_ssc', 'ch1_zc', 'ch1_var', 'ch1_rms', 'ch1_mf','ch1_pf','ch1_activity','ch1_mobility','ch1_complexity','ch1_bp',
    'ch2_wl', 'ch2_mav', 'ch2_ssc', 'ch2_zc', 'ch2_var', 'ch2_rms', 'ch2_mf', 'ch2_pf','ch2_activity','ch2_mobility','ch2_complexity', 'ch2_bp',
    'ch3_wl', 'ch3_mav', 'ch3_ssc', 'ch3_zc', 'ch3_var', 'ch3_rms', 'ch3_mf', 'ch3_pf','ch3_activity','ch3_mobility','ch3_complexity', 'ch3_bp',
    'ch4_wl', 'ch4_mav', 'ch4_ssc', 'ch4_zc', 'ch4_var', 'ch4_rms', 'ch4_mf', 'ch4_pf','ch4_activity','ch4_mobility','ch4_complexity'   , 'ch4_bp'
]


In [21]:
df_train = pd.DataFrame(train_features, columns=columns)
df_train['label'] = y_train
df_train.head()

Unnamed: 0,ch1_wl,ch1_mav,ch1_ssc,ch1_zc,ch1_var,ch1_rms,ch1_mf,ch1_pf,ch1_activity,ch1_mobility,...,ch4_zc,ch4_var,ch4_rms,ch4_mf,ch4_pf,ch4_activity,ch4_mobility,ch4_complexity,ch4_bp,label
0,1771.171829,2856.676753,394.0,108.0,20904940.0,4572.193081,93.87822,101.5625,20904940.0,0.658932,...,129.0,5292122.0,2300.48513,70.844823,50.78125,5292122.0,0.673428,1.086598,10664940.0,1.0
1,14897.738439,16303.247034,320.0,229.0,525831100.0,22931.177332,237.728162,238.28125,525831100.0,0.765532,...,164.0,292700300.0,17108.734586,36.681424,62.5,292700300.0,0.386685,0.725502,51108660.0,3.0
2,3345.050764,8883.969835,444.0,99.0,595626400.0,24405.460258,33.513125,31.25,595626400.0,0.289256,...,60.0,5388652000.0,73407.442405,31.045075,35.15625,5388652000.0,0.228126,0.442768,8666772000.0,2.0
3,756.216943,1806.365095,409.0,99.0,9676177.0,3110.659423,37.123395,35.15625,9676177.0,0.397913,...,70.0,38163470.0,6177.65964,36.965889,31.25,38163470.0,0.313002,0.701886,64419440.0,3.0
4,1094.900863,2163.169242,422.0,94.0,10396580.0,3224.373246,45.687026,46.875,10396580.0,0.521728,...,87.0,11096970.0,3331.211497,33.056264,35.15625,11096970.0,0.386112,0.965219,19288290.0,3.0


In [22]:
df_train.to_csv("processed_data/train_features.csv", index=False)

In [23]:
df_test = pd.DataFrame(test_features, columns=columns)
df_test['label'] = y_test
df_test.head()

Unnamed: 0,ch1_wl,ch1_mav,ch1_ssc,ch1_zc,ch1_var,ch1_rms,ch1_mf,ch1_pf,ch1_activity,ch1_mobility,...,ch4_zc,ch4_var,ch4_rms,ch4_mf,ch4_pf,ch4_activity,ch4_mobility,ch4_complexity,ch4_bp,label
0,4171.194908,6184.499425,397.0,148.0,102759900.0,10137.061374,81.991338,82.03125,102759900.0,0.630489,...,125.0,55217090.0,7430.877979,83.431107,82.03125,55217090.0,0.630074,0.890426,113308000.0,2.0
1,1178.499816,2649.727715,437.0,81.0,20289370.0,4504.372486,53.014577,39.0625,20289370.0,0.430095,...,84.0,46105840.0,6790.129163,19.540307,15.625,46105840.0,0.265384,0.709138,54681320.0,2.0
2,1679.917407,2857.902067,434.0,112.0,24057950.0,4904.897462,92.883094,97.65625,24057950.0,0.597522,...,83.0,43741390.0,6613.739756,23.274787,23.4375,43741390.0,0.311012,0.840163,78658840.0,3.0
3,917.576571,2822.41832,445.0,75.0,20929130.0,4574.858025,24.95543,23.4375,20929130.0,0.316443,...,60.0,651020900.0,25515.134072,27.214335,23.4375,651020900.0,0.220201,0.541424,1209753000.0,3.0
4,12116.328103,66908.26939,486.0,72.0,29367930000.0,171370.732776,19.93203,19.53125,29367930000.0,0.18529,...,66.0,8245029000.0,90802.15555,17.901685,19.53125,8245029000.0,0.161798,0.502352,8542103000.0,2.0


In [24]:
df_test.to_csv("processed_data/test_features.csv", index=False)