In [1]:
from utils import *
import pandas as pd
import mne
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.preprocessing import normalize, StandardScaler
plt.style.use('seaborn-whitegrid')

In [2]:
def get_threshold(scores):
    scores = np.array(scores)
    lower_threshold = scores.mean() - (scores.std()/2)
    upper_threshold = scores.mean() + (scores.std()/2)
    return scores.mean(), (lower_threshold,upper_threshold)

def get_stress_type(score, grade):
    """ Non-stress (0): score < lower_threshold
        Neutral    (1): lower_threshold <= score <= upper_threshold
        Stress     (2): score > lower_threshold """
    if(score < grade[0]):
        return 0
    elif(score <= grade[1]):
        return 1
    elif(score > grade[1]):
        return 2

def PSS_printer(PSS):
    # peak at info
    temp = PSS.popitem()
    PSS[temp[0]] = temp[1]
    column = list(temp[1].keys())
    space = "\t\t"
    print(f"Name{space}",f"{space}".join(column),sep="" )
    print("="*60)
    for name, info in PSS.items():
        print(f"{name}{space}",sep="",end="")
        for col in column:
            print(f"{info[col]}{space}",end="")

        print()


TYPE_DEF = {0:'Non-Stress', 1:'Neutral', 2: 'Stress'}

In [3]:
PSS = dict()
scores = []
with open('./PSS_scores.csv','r') as f:
    f.readline() # skip header
    for line in f.readlines(): 
        name,score = line.split(',')
        PSS[name] = {'score':int(score)}
        scores.append(int(score))

mean, grade = get_threshold(scores)
# print(f"Total={len(PSS)} | Mean={mean} | Lower Thres={grade[0]} | Higher Thres={grade[1]}")

type_count = {0:0, 1:0, 2:0}
for name, dict_info in PSS.items():
    label = get_stress_type(dict_info['score'], grade)
    dict_info['type'] = label
    dict_info['type_definition'] = TYPE_DEF[label]
    type_count[label] = type_count[label] + 1

# print(f"Non Stress={type_count[0]} | Neutral={type_count[1]} | Stress={type_count[2]}")

# PSS_printer(PSS)

In [5]:
sampling_rate = 125 #Hz
files = glob(f"data/*.csv")
for f in tqdm(files):
    name = f.split('/')[1].split('__')[0]
    pd_raw = pd.read_csv(f, dtype={'Marker':str})
    pd_raw = pd_raw.drop(columns='timestamps')
    raw = dataframe_to_raw(pd_raw, sfreq=sampling_rate)
    PSS[name]['raw'] = raw
    # print(f"{name} | time: {len(pd_raw)/125}")

  0%|          | 0/56 [00:00<?, ?it/s]

In [19]:
# save(PSS,"PSS")

PSS = load("PSS")

In [21]:
for name, info in tqdm(PSS.items()):
    raw = info['raw']
    raw.filter(l_freq=1,h_freq=None, method='iir', iir_params={'order':3.0, 'ftype':'butter'}, verbose=False) # Slow drift
    raw.notch_filter(freqs=[50])
    # epochs = mne.Epochs(raw, np.array([[125*60*1, 0, 1]]), tmin=0, tmax=30, baseline=(0,30), verbose=False)
    # print(name)
    # a = epochs.plot_psd(picks=['F3','F4','T3','T4'])
    # print("="*40)

  0%|          | 0/55 [00:00<?, ?it/s]

In [23]:
def get_freq(PSS):
    # peak at info
    temp = PSS.popitem()
    PSS[temp[0]] = temp[1]
    raw = temp[1]['raw']
    power,freq = mne.time_frequency.psd_welch(raw,n_fft=125, verbose=True)
    return freq

freq = get_freq(PSS)
print(freq)

Effective window size : 1.000 (s)
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.
 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53.
 54. 55. 56. 57. 58. 59. 60. 61. 62.]


In [24]:
band_names = np.array(['Delta', 'Theta', 'Alpha', 'Beta', 'Gamma', 'Slow', 'Low_beta'])
filter_list = [[1,3],[4,7],[8,12],[13,30],[30,43], [4,13], [13,17]]
bands = []
for filt in filter_list:
    pt = np.argwhere((freq >= filt[0]) & (freq <= filt[1])).reshape(-1)
    bands.append(pt)
bands = np.array(bands)
print(bands)

[array([1, 2, 3]) array([4, 5, 6, 7]) array([ 8,  9, 10, 11, 12])
 array([13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
        30])
 array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43])
 array([ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13])
 array([13, 14, 15, 16, 17])]


  bands = np.array(bands)


In [69]:
# features = None
csv = None
names = []
for name,info in PSS.items():
    names.append(name)
    raw = info['raw']
    row = None
    feature = None
    slow, gamma = None, None
    a_f3, a_f4 = None, None
    a_t7, a_t8 = None, None
    b_f3, b_f4 = None, None
    b_t7, b_t8 = None, None
    # epochs = mne.Epochs(raw, np.array([[125*60*1, 0, 1]]), tmin=0, tmax=30, baseline=(0,30), verbose=False)
    for index, band in enumerate(bands):
        power,freq = mne.time_frequency.psd_welch(raw,n_fft=125, verbose=False)
        power = power.squeeze()
        power = 10 * np.log10(power)
        data = power[::,band].mean(axis=1).reshape(1,-1)
        # print(f"{data.shape=}")

        if(type(row) == type(None)): row = data.copy()
        else: row = np.concatenate([row,data.copy()], axis=1)

        # print(f"{row.shape=}")
        # for asym
        if(band_names[index] == 'Alpha'):
            a_f3 = data[:,raw.ch_names.index('F3')]
            a_f4 = data[:,raw.ch_names.index('F4')]
            # We use t3 as t7 and t4 as t8
            a_t7 = data[:,raw.ch_names.index('T3')]
            a_t8 = data[:,raw.ch_names.index('T4')]
        if(band_names[index] == 'Beta'):
            b_f3 = data[:,raw.ch_names.index('F3')]
            b_f4 = data[:,raw.ch_names.index('F4')]
            # We use t3 as t7 and t4 as t8
            b_t7 = data[:,raw.ch_names.index('T3')]
            b_t8 = data[:,raw.ch_names.index('T4')]

        ####### Mean for visualization #######
        data = data.mean().reshape(1,-1)
        # for relative gamma
        if(band_names[index] == 'Slow'): slow = data
        if(band_names[index] == 'Gamma'): gamma = data

        if(type(feature) == type(None)): feature = data
        else: feature = np.concatenate([feature, data], axis=1)
    # print(feature.shape)
    # the eighth feature: relative gamma is slow/gamma
    relative_gamma = slow/gamma
    feature = np.concatenate([feature, relative_gamma], axis=1)
    # The asymetry
    alpha_frontal = ((a_f4 - a_f3) / (a_f4 + a_f3)).reshape(1,-1)
    feature = np.concatenate([feature, alpha_frontal], axis=1)
    # alpha_temporal
    alpha_temporal = ((a_t8 - a_t7) / (a_t8 + a_t7)).reshape(1,-1)
    feature = np.concatenate([feature, alpha_temporal], axis=1)
    # alpha_asymmetry
    alpha_asymmetry = alpha_frontal + alpha_temporal
    feature = np.concatenate([feature, alpha_asymmetry], axis=1)
    # beta_frontal
    beta_frontal = ((b_f4 - b_f3) / (b_f4 + b_f3)).reshape(1,-1)
    feature = np.concatenate([feature, beta_frontal], axis=1)
    # beta_temporal
    beta_temporal = ((b_t8 - b_t7) / (b_t8 + b_t7)).reshape(1,-1)
    feature = np.concatenate([feature, beta_temporal], axis=1)

    row = np.concatenate([row, relative_gamma, alpha_frontal, alpha_asymmetry, beta_frontal, beta_temporal], axis=1)
    # print(slow/gamma)
    # print(feature.shape)
    # print(feature)
    info['feature'] = feature
    if(type(csv) == type(None)): csv = row
    else: csv = np.concatenate([csv,row], axis=0)
    print(f"{csv.shape=}")
    # break
    # if(type(features) == type(None)): features = feature
    # else: features = np.concatenate([features, feature], axis=0)
# print(features.shape)


csv.shape=(1, 117)
csv.shape=(2, 117)
csv.shape=(3, 117)
csv.shape=(4, 117)
csv.shape=(5, 117)
csv.shape=(6, 117)
csv.shape=(7, 117)
csv.shape=(8, 117)
csv.shape=(9, 117)
csv.shape=(10, 117)
csv.shape=(11, 117)
csv.shape=(12, 117)
csv.shape=(13, 117)
csv.shape=(14, 117)
csv.shape=(15, 117)
csv.shape=(16, 117)
csv.shape=(17, 117)
csv.shape=(18, 117)
csv.shape=(19, 117)
csv.shape=(20, 117)
csv.shape=(21, 117)
csv.shape=(22, 117)
csv.shape=(23, 117)
csv.shape=(24, 117)
csv.shape=(25, 117)
csv.shape=(26, 117)
csv.shape=(27, 117)
csv.shape=(28, 117)
csv.shape=(29, 117)
csv.shape=(30, 117)
csv.shape=(31, 117)
csv.shape=(32, 117)
csv.shape=(33, 117)
csv.shape=(34, 117)
csv.shape=(35, 117)
csv.shape=(36, 117)
csv.shape=(37, 117)
csv.shape=(38, 117)
csv.shape=(39, 117)
csv.shape=(40, 117)
csv.shape=(41, 117)
csv.shape=(42, 117)
csv.shape=(43, 117)
csv.shape=(44, 117)
csv.shape=(45, 117)
csv.shape=(46, 117)
csv.shape=(47, 117)
csv.shape=(48, 117)
csv.shape=(49, 117)
csv.shape=(50, 117)
csv.shape

In [68]:
with open("feature.csv","w") as f:
    # print(raw.ch_names[:-1])
    # print(band_names)
    header = ["name"]
    for band_name in band_names:
        for ch_name in raw.ch_names[:-1]:
            header.append(f"{ch_name}_{band_name}")
    f.write(",".join(header))


In [56]:
np.savetxt("foo.csv", csv, delimiter=",")

In [71]:
print(names)

['fabby', 'bas', 'flm', 'mind', 'taew', 'MJ', 'nopphon', 'boss', 'film', 'new', 'nice', 'nuclear', 'pang', 'prin', 'amp', 'beau', 'dt', 'int', 'minkhant', 'sam', 'yong', 'aui', 'bank', 'dream', 'eiyu', 'ice', 'job', 'kee', 'miiw', 'noey', 'pear', 'por', 'satya', 'shin', 'suyo', 'tom', 'yee', 'aun', 'bam', 'beer', 'cedric', 'fahmai', 'gon', 'harold', 'kant', 'kao', 'mu', 'nisit', 'pla', 'ploy', 'poon', 'praewphan', 's', 'younten', 'tor']
