In [10]:
import numpy as np
import pandas as pd
from scipy import io as sio

import warnings

warnings.filterwarnings('ignore')

### Load the Data

In [2]:
data = sio.loadmat('Project_data.mat')
channels = data['Channels']
test_data = data['TestData']
train_data = data['TrainData']
train_labels = data['TrainLabels'].ravel()
# loadmat returns MATLAB scalars as 2-D arrays (presumably (1, 1) here); unwrap
# the sampling rate to a plain float so the spectral helpers (scipy welch)
# receive a scalar fs. float() raises if 'fs' unexpectedly holds several values.
fs = float(np.squeeze(data['fs']))

In [3]:
# Sanity-check the loaded arrays before extracting features. The calc_*
# helpers index data[channel, sample, trial], so train and test must agree on
# the first two axes, and there must be exactly one label per training trial.
assert train_data.shape[2] == train_labels.shape[0], 'expected one label per training trial'
assert train_data.shape[:2] == test_data.shape[:2], 'train/test channel or sample counts differ'
train_labels.shape, train_data.shape, test_data.shape

((550,), (59, 5000, 550), (59, 5000, 159))

### Functions to Extract Features 

In [4]:
from sklearn.linear_model import LinearRegression
from scipy.signal import welch
from scipy.integrate import simps

def calc_variance(data):
    """Per-channel spread of every trial.

    NOTE: despite its name this returns the *standard deviation* (``np.std``,
    ddof=0) taken over axis 1, not the variance.

    Parameters
    ----------
    data : np.ndarray
        Array indexed elsewhere in this notebook as ``data[channel, sample, trial]``.

    Returns
    -------
    np.ndarray
        For 3-D input, shape ``(num_trials, num_channels)``.
    """
    per_channel_std = np.std(data, axis=1)  # (channels, trials) for 3-D input
    return np.transpose(per_channel_std)

def calc_amp_hist(data, num_bins, return_counts=False):
    """Amplitude-histogram features for every (trial, channel) signal.

    Parameters
    ----------
    data : np.ndarray
        3-D array indexed as ``data[channel, sample, trial]``.
    num_bins : int
        Number of equal-width bins spanning each signal's own min..max range.
    return_counts : bool, default False
        False (legacy behaviour): store the ``num_bins + 1`` bin *edges*
        returned by ``np.histogram``. Because the bins are equal-width over the
        signal's range, these edges only encode its min and max.
        True: store the ``num_bins`` bin *counts*, i.e. the actual amplitude
        distribution of the signal.

    Returns
    -------
    np.ndarray
        Shape ``(num_trials, num_channels * width)`` where ``width`` is
        ``num_bins + 1`` (edges) or ``num_bins`` (counts); channel j occupies
        columns ``[j * width, (j + 1) * width)``.
    """
    num_channels = data.shape[0]
    num_trials = data.shape[2]
    width = num_bins if return_counts else num_bins + 1
    # np.histogram returns (counts, edges); select the part the caller wants.
    part = 0 if return_counts else 1

    features = np.zeros((num_trials, num_channels * width))
    for i in range(num_trials):
        for j in range(num_channels):
            features[i, j * width:(j + 1) * width] = np.histogram(data[j, :, i], num_bins)[part]

    return features

def calc_ar(data, model_size, fit_intercept):
    """Autoregressive (AR) model coefficients for every (trial, channel) signal.

    For each 1-D signal ``x = data[channel, :, trial]`` an ordinary least-squares
    ``LinearRegression`` predicts ``x[k]`` from the ``model_size`` preceding
    samples ``x[k - model_size:k]`` for every ``k >= model_size``.

    Parameters
    ----------
    data : np.ndarray
        3-D array indexed as ``data[channel, sample, trial]``.
    model_size : int
        AR order, i.e. the number of lagged samples used as regressors.
    fit_intercept : bool
        Whether to also fit (and store) an intercept term.

    Returns
    -------
    np.ndarray
        Shape ``(num_trials, num_channels * n_params)`` with
        ``n_params = model_size + 1`` if ``fit_intercept`` else ``model_size``.
        Channel j occupies columns ``[j * n_params, (j + 1) * n_params)``:
        ``[intercept, coef_0, ..., coef_{p-1}]`` when an intercept is fitted,
        otherwise just the coefficients. ``coef_m`` multiplies
        ``x[k - model_size + m]``, so the oldest lag comes first.
    """
    num_channels = data.shape[0]
    num_trials = data.shape[2]
    n_params = model_size + int(bool(fit_intercept))

    features = np.zeros((num_trials, num_channels * n_params))
    for i in range(num_trials):
        for j in range(num_channels):
            signal = data[j, :, i]

            # Row r of the design matrix is the lag window signal[r:r + model_size]
            # and its target is signal[r + model_size]. sliding_window_view builds
            # all rows as a strided view, avoiding a per-sample Python loop.
            lagged = np.lib.stride_tricks.sliding_window_view(signal[:-1], model_size)
            target = signal[model_size:]

            lr = LinearRegression(fit_intercept=fit_intercept)
            lr.fit(lagged, target)

            coefs = np.insert(lr.coef_, 0, lr.intercept_) if fit_intercept else lr.coef_
            features[i, j * n_params:(j + 1) * n_params] = coefs

    return features

def calc_correlation(data):
    """Channel-pair covariance of every trial, flattened into one row per trial.

    For trial i and channels j, k, the feature in column ``j * C + k``
    (``C = data.shape[0]``) is

        mean_t[(x_j[t] - mean(x_j)) * (x_k[t] - mean(x_k))],  x_j = data[j, :, i]

    i.e. the population (ddof=0) covariance. It is not divided by the channel
    standard deviations, so despite the name this is a covariance rather than a
    Pearson correlation. The matrix is symmetric, so every off-diagonal pair
    appears twice.

    NOTE: an earlier version misplaced a parenthesis and computed
    ``mean((x_j - m_j) * x_k) - m_k`` (covariance minus the mean of channel k).
    Features cached from that version (e.g. cross_corr.npy) must be regenerated.

    Parameters
    ----------
    data : np.ndarray
        3-D array indexed as ``data[channel, sample, trial]``.

    Returns
    -------
    np.ndarray
        Shape ``(num_trials, num_channels ** 2)``.
    """
    num_channels = data.shape[0]
    signal_length = data.shape[1]
    num_trials = data.shape[2]

    features = np.zeros((num_trials, num_channels ** 2))
    for i in range(num_trials):
        trial = data[:, :, i]  # (channels, samples)
        centered = trial - trial.mean(axis=1, keepdims=True)
        # (C, T) @ (T, C) -> (C, C) covariance matrix; ravel() is row-major, so
        # entry (j, k) lands in column j * C + k.
        features[i] = (centered @ centered.T / signal_length).ravel()

    return features

def calc_max_freq(data, fs):
    """Peak frequency of each (trial, channel) signal.

    The PSD is estimated with ``scipy.signal.welch`` (nperseg=2048); the
    feature is the frequency bin with the largest PSD value.

    Parameters
    ----------
    data : np.ndarray
        3-D array indexed as ``data[channel, sample, trial]``.
    fs : float or array-like
        Sampling rate forwarded to ``welch``. The returned frequencies are
        flattened, presumably so a size-1 array (as loadmat yields) also works.

    Returns
    -------
    np.ndarray
        Shape ``(num_trials, num_channels)``.
    """
    n_chan, n_trial = data.shape[0], data.shape[2]
    peak = np.zeros((n_trial, n_chan))
    for trial in range(n_trial):
        for ch in range(n_chan):
            freqs, power = welch(data[ch, :, trial], fs=fs, nperseg=2048)
            peak[trial, ch] = np.ravel(freqs)[power.argmax()]
    return peak

def calc_mean_freq(data, fs):
    """Power-weighted mean frequency of each (trial, channel) signal.

    Computes ``sum(f * P(f)) / sum(P(f))`` over the Welch PSD
    (``scipy.signal.welch``, nperseg=2048). A signal whose PSD is identically
    zero yields nan (0 / 0).

    Parameters
    ----------
    data : np.ndarray
        3-D array indexed as ``data[channel, sample, trial]``.
    fs : float or array-like
        Sampling rate forwarded to ``welch``; frequencies are flattened
        before use.

    Returns
    -------
    np.ndarray
        Shape ``(num_trials, num_channels)``.
    """
    n_chan, n_trial = data.shape[0], data.shape[2]
    centroid = np.zeros((n_trial, n_chan))
    for trial in range(n_trial):
        for ch in range(n_chan):
            freqs, power = welch(data[ch, :, trial], fs=fs, nperseg=2048)
            freqs = np.ravel(freqs)
            weighted = (freqs * power).sum()
            centroid[trial, ch] = weighted / power.sum()
    return centroid

def calc_median_freq(data, fs):
    """Median frequency of each (trial, channel) signal.

    The median frequency is the first Welch-PSD bin (``scipy.signal.welch``,
    nperseg=2048) at which the cumulative power reaches half of the total power.

    Parameters
    ----------
    data : np.ndarray
        3-D array indexed as ``data[channel, sample, trial]``.
    fs : float or array-like
        Sampling rate forwarded to ``welch``; frequencies are flattened
        before use.

    Returns
    -------
    np.ndarray
        Shape ``(num_trials, num_channels)``.
    """
    n_chan, n_trial = data.shape[0], data.shape[2]
    median = np.zeros((n_trial, n_chan))
    for trial in range(n_trial):
        for ch in range(n_chan):
            freqs, power = welch(data[ch, :, trial], fs=fs, nperseg=2048)
            running = np.cumsum(power)
            reached_half = running >= running[-1] / 2
            # first bin where half the total power has been accumulated
            half_idx = np.nonzero(reached_half)[0][0]
            median[trial, ch] = np.ravel(freqs)[half_idx]
    return median

def calc_rel_energy(data, fs, bands):
    """Relative spectral energy of each (trial, channel) signal in each band.

    The PSD is estimated with ``scipy.signal.welch`` (nperseg=2048) and
    integrated with Simpson's rule. Each band's integral is divided by the
    integral over the whole spectrum. A PSD that is identically zero yields nan.

    Parameters
    ----------
    data : np.ndarray
        3-D array indexed as ``data[channel, sample, trial]``.
    fs : float or array-like
        Sampling rate forwarded to ``welch``. The returned frequencies are
        flattened, presumably so a size-1 array (as loadmat yields) also works.
    bands : dict[str, tuple[float, float]]
        Ordered mapping of band name -> (low, high) in the units of ``fs``.
        Both edges are inclusive, so a bin on a shared edge counts toward both
        neighbouring bands.

    Returns
    -------
    np.ndarray
        Shape ``(num_trials, num_channels * len(bands))``; channel j occupies
        columns ``[j * len(bands), (j + 1) * len(bands))`` in ``bands`` order.
    """
    # scipy.integrate.simps was deprecated in SciPy 1.12 and removed in 1.14;
    # simpson (available since 1.6) replaces it. x is passed by keyword because
    # newer SciPy releases make it keyword-only.
    from scipy.integrate import simpson

    num_channels = data.shape[0]
    num_trials = data.shape[2]
    num_bands = len(bands)

    features = np.zeros((num_trials, num_channels * num_bands))
    for i in range(num_trials):
        for j in range(num_channels):
            frequencies, psd = welch(data[j, :, i], fs=fs, nperseg=2048)
            frequencies = frequencies.ravel()

            total_energy = simpson(psd, x=frequencies)

            for k, (low, high) in enumerate(bands.values()):
                in_band = (frequencies >= low) & (frequencies <= high)
                band_energy = simpson(psd[in_band], x=frequencies[in_band])
                features[i, j * num_bands + k] = band_energy / total_energy

    return features

### Feature Extraction

In [79]:
# Fitting the AR models is slow, so reuse ar_no_intercept.npy when it already
# exists and only fit (and cache) the coefficients when it is missing.
# Delete the file to force a refit, e.g. after changing calc_ar's arguments.
import os

if os.path.exists('ar_no_intercept.npy'):
    ar_no_intercept = np.load('ar_no_intercept.npy')
else:
    ar_no_intercept = calc_ar(train_data, 10, False)
    np.save('ar_no_intercept', ar_no_intercept)

In [87]:
# Slow: computes the channel-pair features for every training trial.
# Deliberately not cache-guarded so cross_corr.npy is regenerated whenever
# calc_correlation changes; the feature-loading cell below reloads this file.
cross_corr = calc_correlation(train_data)
np.save(file='cross_corr', arr=cross_corr)

#### Time Features

In [8]:
# Training-set time-domain features. The two slow feature sets (AR
# coefficients and channel-pair features) are loaded from the .npy caches
# written by the cells above rather than recomputed here.
var = calc_variance(train_data)
amp_hist = calc_amp_hist(train_data, 10)
ar_model = np.load('ar_no_intercept.npy')
cross_corr = np.load('cross_corr.npy')

# Fail fast on stale caches (e.g. written from other data or settings): one row
# per training trial, with the widths implied by the computing cells
# (10 AR lags per channel, one value per ordered channel pair).
assert ar_model.shape == (train_data.shape[2], train_data.shape[0] * 10), 'ar_no_intercept.npy looks stale'
assert cross_corr.shape == (train_data.shape[2], train_data.shape[0] ** 2), 'cross_corr.npy looks stale'

In [15]:
# Test-set time-domain features, computed directly (no cache) with the same
# settings as the training set: 10 histogram bins, AR order 10, no intercept.
var_test, amp_hist_test, ar_model_test, cross_corr_test = (
    calc_variance(test_data),
    calc_amp_hist(test_data, 10),
    calc_ar(test_data, 10, False),
    calc_correlation(test_data),
)

In [12]:
tuple(arr.shape for arr in (var, amp_hist, ar_model, cross_corr))

((550, 59), (550, 649), (550, 590), (550, 3481))

#### Frequency Features

In [11]:
# Frequency bands in Hz as (low, high). calc_rel_energy treats both edges as
# inclusive, so a bin lying exactly on a shared edge counts toward both bands.
# NOTE(review): Gamma's 500 Hz upper edge presumably matches the Nyquist
# frequency fs / 2 — confirm against the loaded fs.
bands = dict([
    ('Delta', (0.1, 4)),
    ('Theta', (4, 8)),
    ('Alpha', (8, 12)),
    ('Low-Range Beta', (12, 16)),
    ('Mid-Range Beta', (16, 21)),
    ('High-Range Beta', (21, 30)),
    ('Gamma', (30, 500)),
])

# Training-set spectral features; each calc_* call runs Welch's method on
# every (trial, channel) signal.
max_freq = calc_max_freq(train_data, fs)
mean_freq = calc_mean_freq(train_data, fs)
med_freq = calc_median_freq(train_data, fs)
rel_energy = calc_rel_energy(train_data, fs, bands)

In [16]:
# Test-set spectral features, using the same fs and band definitions as the
# training set.
max_freq_test, mean_freq_test, med_freq_test, rel_energy_test = (
    calc_max_freq(test_data, fs),
    calc_mean_freq(test_data, fs),
    calc_median_freq(test_data, fs),
    calc_rel_energy(test_data, fs, bands),
)

In [13]:
tuple(arr.shape for arr in (max_freq, mean_freq, med_freq, rel_energy))

((550, 59), (550, 59), (550, 59), (550, 413))

#### Aggregate and Save

In [18]:
# Persist each training feature block as <name>.npy in the working directory.
for name, array in (
    ('var', var),
    ('amp_hist', amp_hist),
    ('ar_model', ar_model),
    ('cross_corr', cross_corr),
    ('max_freq', max_freq),
    ('mean_freq', mean_freq),
    ('med_freq', med_freq),
    ('rel_energy', rel_energy),
):
    np.save(name, array)

In [19]:
# Persist each test feature block as <name>.npy in the working directory.
for name, array in (
    ('var_test', var_test),
    ('amp_hist_test', amp_hist_test),
    ('ar_model_test', ar_model_test),
    ('cross_corr_test', cross_corr_test),
    ('max_freq_test', max_freq_test),
    ('mean_freq_test', mean_freq_test),
    ('med_freq_test', med_freq_test),
    ('rel_energy_test', rel_energy_test),
):
    np.save(name, array)

In [None]:
# Stack every per-trial feature block side by side (axis=1) into one design
# matrix per split, in the same order the blocks are saved above. All blocks
# have one row per trial, so only the column counts differ.
features_tr = np.concatenate(
    [var, amp_hist, ar_model, cross_corr, max_freq, mean_freq, med_freq, rel_energy],
    axis=1,
)
features_te = np.concatenate(
    [var_test, amp_hist_test, ar_model_test, cross_corr_test,
     max_freq_test, mean_freq_test, med_freq_test, rel_energy_test],
    axis=1,
)