In [1]:
from pathlib import Path
import numpy as np
import mne
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from scipy import signal

In [2]:
data_path = Path('../') / 'EEG_data' / 'cleaned_data'
sample = data_path / 'sub-001_cleaned-epo.fif'

In [3]:
cleaned_data_file = sample
epochs = mne.read_epochs(cleaned_data_file, preload=True, verbose=False)
raw_cleaned_df = epochs.to_data_frame()
raw_cleaned_df 

channels = ['C3:A2', 'C4:A1', 'O2:A1', 'F4:A1', 'O1:A2', 'F3:A2']
raw_cleaned_df

Unnamed: 0,time,condition,epoch,ECG II,EMG1,EMG2,EMG3,C4:A1,O2:A1,F4:A1,C3:A2,O1:A2,F3:A2,EOG1:A2,EOG2:EOG1,EMG2:EMG3,EMG1:EMG3,EOG2:A1
0,0.000000,A,0,-1.020675e-13,1.588187e-16,3.705769e-16,-7.146840e-16,4.235165e-16,-4.605742e-15,-3.366956e-14,-3.599890e-15,8.470329e-15,-1.778769e-14,-4.743384e-14,6.691560e-14,9.529121e-16,-1.323489e-16,1.355253e-14
1,0.003906,A,0,1.362166e+00,-4.281903e-02,-2.598510e-03,-2.333368e-02,-7.085808e-01,-1.673215e+00,-9.580508e-01,-3.060751e-01,-7.371493e-01,-3.267541e-01,-8.987415e-01,8.471936e-01,2.073517e-02,-1.948534e-02,-4.657830e-01
2,0.007812,A,0,8.771957e-01,-2.639470e-02,1.075816e-04,-1.550780e-02,-4.920361e-01,-1.144155e+00,-6.800739e-01,-1.675360e-01,-4.571859e-01,-2.326080e-01,-4.968694e-01,4.018867e-01,1.561538e-02,-1.088690e-02,-4.467325e-01
3,0.011719,A,0,-8.680287e-01,3.159021e-02,6.716462e-03,1.503837e-02,3.520358e-01,8.742931e-01,4.263973e-01,2.701067e-01,5.157499e-01,1.416501e-01,7.786723e-01,-9.111621e-01,-8.321904e-03,1.655184e-02,-9.678365e-02
4,0.015625,A,0,-1.628463e+00,5.945348e-02,1.084394e-02,3.084874e-02,6.967925e-01,1.699362e+00,8.515547e-01,4.646388e-01,9.441912e-01,2.821637e-01,1.341473e+00,-1.513742e+00,-2.000480e-02,2.860474e-02,-5.038682e-03
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11022230,29.984375,Wake,1434,1.771239e+01,-4.785231e-01,6.029068e-01,-1.089177e+00,-1.632440e+01,-8.294277e+00,-1.967413e+01,-3.331244e+00,7.990916e-01,-1.054227e+01,-3.507020e+01,2.298886e+00,1.692084e+00,6.106545e-01,-3.945805e+01
11022231,29.988281,Wake,1434,4.517182e+00,-3.852543e-01,3.131050e-01,-6.832829e-01,-1.534606e+01,-7.673178e+00,-1.899928e+01,-2.044198e+00,3.213013e-01,-1.031658e+01,-3.244126e+01,3.149710e+00,9.963879e-01,2.980286e-01,-3.535804e+01
11022232,29.992188,Wake,1434,6.045823e-01,-2.618192e-01,2.296895e-01,-6.083463e-01,-6.970833e+00,-5.288079e+00,-3.698727e+00,-4.497857e+00,-3.852600e+00,-1.421433e+01,-3.273737e+01,-3.571930e+00,8.380359e-01,3.465271e-01,-3.600307e+01
11022233,29.996094,Wake,1434,-1.388293e-01,3.571765e-02,5.325797e-01,-3.710119e-01,-6.669784e+00,-5.877498e+00,-3.289416e+00,-2.278697e+00,-3.775646e+00,-1.217103e+01,-2.908987e+01,-1.495797e+00,9.035914e-01,4.067295e-01,-3.126076e+01


In [309]:
bands = {
    "delta": (0.5, 4),
    "theta": (4, 8),
    "alpha": (8, 12),
    "sigma": (12, 15),
    "beta": (15, 30),
}

# def get_band_power(fft_freqs, intensity, band):
#     upper = (np.abs(fft_freqs) <= bands['delta'][1])
#     lower = (np.abs(fft_freqs) >= bands['delta'][0])

#     power = intensity[upper & lower].sum()
#     return power

def get_band_power(fft_freqs, psd, band):
    low_freq, high_freq = bands[band]
    freq_res = fft_freqs[1] - fft_freqs[0]
    band_indices = np.where((fft_freqs >= low_freq) & (fft_freqs <= high_freq))[0]
    power = np.sum(psd[band_indices]) * freq_res
    return power 

In [320]:
def feature_extract_fft(raw_cleaned_df, epoch, channel):
    epoch_filt = raw_cleaned_df["epoch"] == epoch
    s = raw_cleaned_df[epoch_filt][channel]
    t = raw_cleaned_df[epoch_filt]["time"]

    label = raw_cleaned_df[epoch_filt]['condition'].unique()[0]

    d = 30 / (len(t) - 1)
    n = len(t)

    fft = 2 * np.fft.fft(s) / n
    fft_freqs = np.fft.fftfreq(n, d)

    intensity = np.abs(fft) ** 2

    peak_freq = np.abs(fft_freqs[np.argmax(intensity)])
    total_power = intensity.sum()
    delta_power = get_band_power(fft_freqs, intensity, "delta")
    theta_power = get_band_power(fft_freqs, intensity, "theta")
    alpha_power = get_band_power(fft_freqs, intensity, "alpha")
    sigma_power = get_band_power(fft_freqs, intensity, "sigma")
    beta_power = get_band_power(fft_freqs, intensity, "beta")

    

    return {
        "peak_freq": peak_freq,
        "total_power": total_power,
        "delta_power": delta_power,
        "theta_power": theta_power,
        "alpha_power": alpha_power,
        "sigma_power": sigma_power,
        "beta_power": beta_power,
    }, label


def feature_extract_welch(raw_cleaned_df, epoch, channel):
    epoch_filt = raw_cleaned_df["epoch"] == epoch
    s = raw_cleaned_df[epoch_filt][channel].to_numpy()
    t = raw_cleaned_df[epoch_filt]["time"]

    fs = 256
    nperseg = 2 * fs
    # The 'welch' function returns frequencies and power spectral density
    fft_freqs, intensity = signal.welch(s, fs=fs, nperseg=nperseg)

    label = raw_cleaned_df[epoch_filt]['condition'].unique()[0]

    peak_freq = np.abs(fft_freqs[np.argmax(intensity)])
    total_power = intensity.sum()
    delta_power = get_band_power(fft_freqs, intensity, "delta")
    theta_power = get_band_power(fft_freqs, intensity, "theta")
    alpha_power = get_band_power(fft_freqs, intensity, "alpha")
    sigma_power = get_band_power(fft_freqs, intensity, "sigma")
    beta_power = get_band_power(fft_freqs, intensity, "beta")

    return {
        "peak_freq": peak_freq,
        "total_power": total_power,
        "delta_power": delta_power,
        "theta_power": theta_power,
        "alpha_power": alpha_power,
        "sigma_power": sigma_power,
        "beta_power": beta_power,
        "delta_relative": delta_power / total_power if total_power > 0 else 0,
        "theta_relative": theta_power / total_power if total_power > 0 else 0,
        "alpha_relative": alpha_power / total_power if total_power > 0 else 0,
        "sigma_relative": sigma_power / total_power if total_power > 0 else 0,
        "beta_relative": beta_power / total_power if total_power > 0 else 0,
    }, label


def feature_extract_welch_only_rel(raw_cleaned_df, epoch, channel):
    epoch_filt = raw_cleaned_df["epoch"] == epoch
    s = raw_cleaned_df[epoch_filt][channel].to_numpy()
    t = raw_cleaned_df[epoch_filt]["time"]

    fs = 256
    nperseg = 2 * fs
    # The 'welch' function returns frequencies and power spectral density
    fft_freqs, intensity = signal.welch(s, fs=fs, nperseg=nperseg)

    label = raw_cleaned_df[epoch_filt]['condition'].unique()[0]

    peak_freq = np.abs(fft_freqs[np.argmax(intensity)])
    total_power = intensity.sum()
    delta_power = get_band_power(fft_freqs, intensity, "delta")
    theta_power = get_band_power(fft_freqs, intensity, "theta")
    alpha_power = get_band_power(fft_freqs, intensity, "alpha")
    sigma_power = get_band_power(fft_freqs, intensity, "sigma")
    beta_power = get_band_power(fft_freqs, intensity, "beta")

    return {
        "peak_freq": peak_freq,
        "total_power": total_power,
        "delta_relative": delta_power / total_power if total_power > 0 else 0,
        "theta_relative": theta_power / total_power if total_power > 0 else 0,
        "alpha_relative": alpha_power / total_power if total_power > 0 else 0,
        "sigma_relative": sigma_power / total_power if total_power > 0 else 0,
        "beta_relative": beta_power / total_power if total_power > 0 else 0,
    }, label


In [324]:
res_df = {}
label_df = {}

# channels_to_keep = [
#         "F4:M1", "C4:M1", 'O1:M2', 
#         'F3:M2', 'C3:M2', 'O2:M1']



# for epoch in raw_cleaned_df["epoch"].unique():
#     res_df[epoch], label_df[epoch] = feature_extract(raw_cleaned_df, epoch, 'C4:A1')

channels_to_keep = ['C3:A2', 'C4:A1', 'O2:A1', 
                    'F4:A1', 'O1:A2', 'F3:A2']

all_epoch_features = [] 

for epoch in raw_cleaned_df["epoch"].unique():
    features_for_this_epoch = {} 
    label = None 
    for ch in channels_to_keep:
        features, label = feature_extract_welch(raw_cleaned_df, epoch, ch)
        for feature_name, value in features.items():
            features_for_this_epoch[f"{ch}_{feature_name}"] = value
    features_for_this_epoch['label'] = label

    all_epoch_features.append(features_for_this_epoch)

final_df = pd.DataFrame(all_epoch_features)



In [327]:
#Split data into

# X = pd.DataFrame(res_df).T
# y = pd.Series(label_df)

X = final_df.drop('label', axis=1)
y = final_df['label']

train_size = int(0.8*len(X))
test_size = len(X) - train_size

shuffle = np.random.choice(np.arange(len(X)), len(X), replace=False)
train_idx = shuffle[0:train_size]
test_idx = shuffle[train_size:]

In [337]:
model = GradientBoostingClassifier()
model.fit(X.iloc[train_idx], y.iloc[train_idx])

train_pred = model.predict(X.iloc[train_idx])
test_pred = model.predict(X.iloc[test_idx])

train_correct = train_pred == y.iloc[train_idx]
test_correct = test_pred == y.iloc[test_idx]

train_accuracy = train_correct.mean()
test_accuracy = test_correct.mean()

print(f"Training Accuracy: {train_accuracy:.4f}")
print(f"Testing Accuracy: {test_accuracy:.4f}")

Training Accuracy: 1.0000
Testing Accuracy: 0.7674


In [339]:
unique_labels = np.unique(train_pred)
confusion_train = pd.DataFrame(index = unique_labels, columns=unique_labels)
confusion_test = pd.DataFrame(index = unique_labels, columns=unique_labels)

for col in confusion_train.columns:
    for row in confusion_train.index:
        confusion_train.loc[row, col] = (y.iloc[train_idx][train_pred == row] == col).sum()

for col in confusion_test.columns:
    for row in confusion_test.index:
        confusion_test.loc[row, col] = (y.iloc[test_idx][test_pred == row] == col).sum()



In [340]:
confusion_train




Unnamed: 0,A,N1,N2,N3,REM,Wake
A,1,0,0,0,0,0
N1,0,156,0,0,0,0
N2,0,0,349,0,0,0
N3,0,0,0,89,0,0
REM,0,0,0,0,143,0
Wake,0,0,0,0,0,412


In [341]:
confusion_test


Unnamed: 0,A,N1,N2,N3,REM,Wake
A,0,0,0,0,0,2
N1,0,25,7,0,8,6
N2,0,4,68,0,1,7
N3,0,0,4,22,0,1
REM,0,3,0,0,26,1
Wake,0,10,8,1,4,80


In [287]:
#with one channel 0.59 
#with all on patient 1: 0.55
#with one channel and welch 0.61
#all channels and subject 3  and welch - 0.76
#fft and all channels and subject 3 - 0.663
#welch with all channels and relative pds 0.82
#welch all channel and only relative 0.74 
