In [5]:
from utils import *
import pandas as pd
import mne
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.preprocessing import normalize, StandardScaler
plt.style.use('seaborn-whitegrid')

# Prepare Data

In [6]:
def get_threshold(scores):
    scores = np.array(scores)
    lower_threshold = scores.mean() - (scores.std()/2)
    upper_threshold = scores.mean() + (scores.std()/2)
    return scores.mean(), (lower_threshold,upper_threshold)

def get_stress_type(score, grade):
    """ Non-stress (0): score < lower_threshold
        Neutral    (1): lower_threshold <= score <= upper_threshold
        Stress     (2): score > lower_threshold """
    if(score < grade[0]):
        return 0
    elif(score <= grade[1]):
        return 1
    elif(score > grade[1]):
        return 2

def PSS_printer(PSS):
    # peak at info
    temp = PSS.popitem()
    PSS[temp[0]] = temp[1]
    column = list(temp[1].keys())
    space = "\t\t"
    print(f"Name{space}",f"{space}".join(column),sep="" )
    print("="*60)
    for name, info in PSS.items():
        print(f"{name}{space}",sep="",end="")
        for col in column:
            print(f"{info[col]}{space}",end="")

        print()


TYPE_DEF = {0:'Non-Stress', 1:'Neutral', 2: 'Stress'}

In [7]:
PSS = dict()
scores = []
with open('./PSS_scores.csv','r') as f:
    f.readline() # skip header
    for line in f.readlines(): 
        name,score = line.split(',')
        PSS[name] = {'score':int(score)}
        scores.append(int(score))

mean, grade = get_threshold(scores)
# print(f"Total={len(PSS)} | Mean={mean} | Lower Thres={grade[0]} | Higher Thres={grade[1]}")

type_count = {0:0, 1:0, 2:0}
for name, dict_info in PSS.items():
    label = get_stress_type(dict_info['score'], grade)
    dict_info['type'] = label
    dict_info['type_definition'] = TYPE_DEF[label]
    type_count[label] = type_count[label] + 1

# print(f"Non Stress={type_count[0]} | Neutral={type_count[1]} | Stress={type_count[2]}")

# PSS_printer(PSS)

In [13]:
try:
    PSS = load("PSS")
except:
    sampling_rate = 125 #Hz
    files = glob(f"data/*.csv")
    for f in tqdm(files):
        name = f.split('/')[1].split('__')[0]
        pd_raw = pd.read_csv(f, dtype={'Marker':str})
        pd_raw = pd_raw.drop(columns='timestamps')
        raw = dataframe_to_raw(pd_raw, sfreq=sampling_rate)
        PSS[name]['raw'] = raw
        # print(f"{name} | time: {len(pd_raw)/125}")

    save(PSS,"PSS")



# Processing

In [17]:
for name, info in tqdm(PSS.items()):
    raw = info['raw']
    raw.filter(l_freq=1,h_freq=None, method='iir', iir_params={'order':3.0, 'ftype':'butter'}, verbose=False) # Slow drift
    raw.notch_filter(freqs=[50])
    # epochs = mne.Epochs(raw, np.array([[125*60*1, 0, 1]]), tmin=0, tmax=30, baseline=(0,30), verbose=False)
    # print(name)
    # a = epochs.plot_psd(picks=['F3','F4','T3','T4'])
    # print("="*40)

def get_freq(PSS):
    # peak at info
    temp = PSS.popitem()
    PSS[temp[0]] = temp[1]
    raw = temp[1]['raw']
    power,freq = mne.time_frequency.psd_welch(raw,n_fft=125, verbose=True)
    return freq

freq = get_freq(PSS)
print(f"{freq=}")

band_names = np.array(['Delta', 'Theta', 'Alpha', 'Beta', 'Gamma', 'Slow', 'Low_beta'])
filter_list = [[1,3],[4,7],[8,12],[13,30],[30,43], [4,13], [13,17]]
bands = []
for filt in filter_list:
    pt = np.argwhere((freq >= filt[0]) & (freq <= filt[1])).reshape(-1)
    bands.append(pt)
bands = np.array(bands)
print(f"{bands=}")

  0%|          | 0/55 [00:00<?, ?it/s]

Effective window size : 1.000 (s)
freq=array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
       26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38.,
       39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51.,
       52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62.])
bands=array([array([1, 2, 3]), array([4, 5, 6, 7]), array([ 8,  9, 10, 11, 12]),
       array([13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
              30])                                                               ,
       array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43]),
       array([ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13]),
       array([13, 14, 15, 16, 17])], dtype=object)


  bands = np.array(bands)


In [18]:
# features = None
for name,info in PSS.items():
    raw = info['raw']
    feature = None
    slow, gamma = None, None
    a_f3, a_f4 = None, None
    a_t7, a_t8 = None, None
    b_f3, b_f4 = None, None
    b_t7, b_t8 = None, None
    
    epochs = mne.Epochs(raw, np.array([[125*60*1, 0, 1],[125*60*1.5,0,1]]), tmin=0, tmax=30, baseline=(0,30), verbose=False)
    for evoke in epochs.iter_evoked:
    for index, band in enumerate(bands):
        power,freq = mne.time_frequency.psd_welch(epochs,n_fft=125, verbose=False)
        power = power.squeeze()
        power = 10 * np.log10(power)
        data = power[::,band].mean(axis=1).reshape(1,-1)
        # for asym
        if(band_names[index] == 'Alpha'):
            a_f3 = data[:,raw.ch_names.index('F3')]
            a_f4 = data[:,raw.ch_names.index('F4')]
            # We use t3 as t7 and t4 as t8
            a_t7 = data[:,raw.ch_names.index('T3')]
            a_t8 = data[:,raw.ch_names.index('T4')]
        if(band_names[index] == 'Beta'):
            b_f3 = data[:,raw.ch_names.index('F3')]
            b_f4 = data[:,raw.ch_names.index('F4')]
            # We use t3 as t7 and t4 as t8
            b_t7 = data[:,raw.ch_names.index('T3')]
            b_t8 = data[:,raw.ch_names.index('T4')]

        ####### Mean for visualization #######
        data = data.mean().reshape(1,-1)
        # for relative gamma
        if(band_names[index] == 'Slow'): slow = data
        if(band_names[index] == 'Gamma'): gamma = data

        if(type(feature) == type(None)): feature = data
        else: feature = np.concatenate([feature, data], axis=1)
    # print(feature.shape)
    # the eighth feature: relative gamma is slow/gamma
    relative_gamma = slow/gamma
    feature = np.concatenate([feature, relative_gamma], axis=1)
    # The asymetry
    alpha_frontal = ((a_f4 - a_f3) / (a_f4 + a_f3)).reshape(1,-1)
    feature = np.concatenate([feature, alpha_frontal], axis=1)
    # alpha_temporal
    alpha_temporal = ((a_t8 - a_t7) / (a_t8 + a_t7)).reshape(1,-1)
    feature = np.concatenate([feature, alpha_temporal], axis=1)
    # alpha_asymmetry
    alpha_asymmetry = alpha_frontal + alpha_temporal
    feature = np.concatenate([feature, alpha_asymmetry], axis=1)
    # beta_frontal
    beta_frontal = ((b_f4 - b_f3) / (b_f4 + b_f3)).reshape(1,-1)
    feature = np.concatenate([feature, beta_frontal], axis=1)
    # beta_temporal
    beta_temporal = ((b_t8 - b_t7) / (b_t8 + b_t7)).reshape(1,-1)
    feature = np.concatenate([feature, beta_temporal], axis=1)

    # print(slow/gamma)
    # print(feature.shape)
    # print(feature)
    info['feature'] = feature
    # if(type(features) == type(None)): features = feature
    # else: features = np.concatenate([features, feature], axis=0)
# print(features.shape)


# Build Models

In [19]:
# plt.plot(feature)
feature_names = list(band_names)
feature_names.append('Relative_Gamma')
feature_names.append('Alpha_Frontal')
feature_names.append('Alpha_Temporal')
feature_names.append('Alpha_Asymmetry')
feature_names.append('Beta_Frontal')
feature_names.append('Beta_Temporal')
feature_names = np.array(feature_names)
feature_names[[3,10]]
X_ori,y_ori = [], []
filtered_participants = []
filtered_scored = []
for index,(name,info) in enumerate(PSS.items()):
    # Neutral
    if(info['type'] == 1): continue
    # Non-Stress
    elif(info['type'] == 0):
        y_ori.append(0)
    # Stress
    elif(info['type'] == 2):
        y_ori.append(1)
    X_ori.append(info['feature'])
    filtered_participants.append(name)
    filtered_scored.append(info['score'])

In [20]:
def NormJa(data):
    for index, row in enumerate(data):
        min = row.min()
        max = row.max()
        mean = row.mean()
        row = (row - min) / (max - min)
        data[index] = row
        # print(row)
    return data

def StandardJa(data):
    data
    for index, row in enumerate(data):
        mean = row.mean()
        std = row.std()
        row = (row - mean) / std
        data[index] = row
        # print(row)
    return data

In [21]:
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle
X,y = np.array(X_ori).squeeze(axis=1), np.array(y_ori)
# X = normalize(X.copy(), axis=0)
# X = NormJa(X.copy().T).T
X = StandardJa(X.copy().T).T
X_shuff,y_shuff = shuffle(X,y)
print(X.shape, y.shape)

param_grid = dict(kernel=['linear','poly','rbf', 'sigmoid'])#,'precomputed'])
# cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
grid = GridSearchCV(SVC(), param_grid=param_grid, cv=5)
grid.fit(X, y)
print(f"The best parameters are {grid.best_params_} with a score of {grid.best_score_:.2f}")

model = SVC(kernel='rbf')
# model = GaussianNB()
model.fit(X_shuff, y_shuff)
ans = model.predict(X_shuff)
acc = sum(ans == y_shuff) / len(y_shuff)
cross = cross_val_score(model, X_shuff, y_shuff, cv=5)
print(acc, cross.mean(), cross)
print(ans)

(35, 13) (35,)
The best parameters are {'kernel': 'rbf'} with a score of 0.60
0.8285714285714286 0.42857142857142855 [0.57142857 0.28571429 0.14285714 0.71428571 0.42857143]
[1 0 1 0 0 0 1 1 1 1 0 0 1 0 1 1 1 1 1 0 0 1 1 1 1 1 0 1 0 1 1 1 0 1 1]


In [22]:
ans = model.predict(X)
acc = sum(ans == y) / len(y)
print(f"0: Non-Stress {type_count[0]}")
print(f"1: Stress {type_count[2]}")
print(f"Wrong\t|pred|label |Score |Name")
print("="*40)
for index,(i,j) in enumerate(zip(ans,y)):
    wrong = ""
    if(i != j):
        wrong = "X"
    print(f"{wrong}\t|{i}   |{j}     |{filtered_scored[index]}\t   |{filtered_participants[index]}")

0: Non-Stress 16
1: Stress 19
Wrong	|pred|label |Score |Name
	|1   |1     |25	   |fabby
	|1   |1     |25	   |bas
	|1   |1     |37	   |flm
X	|1   |0     |12	   |MJ
X	|0   |1     |28	   |new
	|0   |0     |15	   |nuclear
	|1   |1     |29	   |pang
	|0   |0     |9	   |prin
	|0   |0     |13	   |amp
	|0   |0     |9	   |beau
X	|1   |0     |16	   |dt
	|0   |0     |15	   |int
	|1   |1     |24	   |minkhant
	|1   |1     |27	   |yong
	|1   |1     |28	   |aui
	|1   |1     |25	   |bank
	|1   |1     |24	   |eiyu
	|0   |0     |13	   |job
	|1   |1     |25	   |kee
	|0   |0     |14	   |miiw
	|0   |0     |8	   |noey
	|0   |0     |13	   |shin
	|0   |0     |17	   |suyo
	|1   |1     |24	   |yee
	|1   |1     |24	   |bam
X	|1   |0     |17	   |beer
	|1   |1     |26	   |cedric
	|1   |1     |37	   |fahmai
	|1   |1     |27	   |gon
	|1   |1     |24	   |kao
X	|1   |0     |16	   |mu
X	|1   |0     |11	   |nisit
	|1   |1     |27	   |pla
	|1   |1     |27	   |ploy
	|0   |0     |16	   |praewphan
