In [1]:
import pandas
import numpy as np
from utils import *
import mne
import matplotlib.pyplot as plt
from os import walk
from tqdm.notebook import tqdm

from sklearn.naive_bayes import GaussianNB
from itertools import combinations
from mne.time_frequency import tfr_morlet

# 'seaborn-whitegrid' was renamed to 'seaborn-v0_8-whitegrid' in matplotlib 3.6 and the
# old name was later removed (3.8); prefer the new name and fall back for older versions.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')

# HEXACO questionnaire column -> short trait code used throughout the notebook.
HEXACO_TRAITS = {'Honesty-Humility': 'h',
                 'Emotionality': 'e',
                 'eXtraversion': 'x',
                 'Agreeableness': 'a',
                 'Conscientiousness': 'c',
                 'Openness to Experience': 'o'}

df = pandas.read_csv('./HEXACO.csv')
# Trait scores indexed by participant id.
gt = df[['id', *HEXACO_TRAITS]].rename(columns=HEXACO_TRAITS).set_index('id')

# Binary label per trait ('lh', 'le', ...): 1 when the score is strictly above the
# threshold, otherwise 0 (a missing score compares False and therefore maps to 0).
thold = 3.5
for trait in HEXACO_TRAITS.values():
    gt[f'l{trait}'] = (gt[trait] > thold).astype('int64')

In [2]:
path, folders, filenames = next(walk('./data'))

# Participants excluded because of incomplete data.
EXCLUDED_IDS = [11, 4, 36]
# Rebind instead of drop(..., inplace=True): with errors='ignore' the cell can be re-run
# without the KeyError the in-place drop raised on a second execution.
gt = gt.drop(index=EXCLUDED_IDS, errors='ignore')

# Recordings that are not loaded. Only participants 11 and 36 have files removed here;
# participant 4 is excluded from `gt` only, as in the original notebook.
EXCLUDED_FILES = {'11-audio.csv', '11-image.csv',
                  '36-audio.csv', '36-image.csv'}
filenames = [name for name in filenames if name not in EXCLUDED_FILES]

In [3]:
path = './data'
# Unnamed CSV columns -> electrode labels.
columns = {'Unnamed: 1': 'Fp1',
           'Unnamed: 2': 'Fp2',
           'Unnamed: 3': 'F3',
           'Unnamed: 4': 'F4',
           'Unnamed: 5': 'F7',
           'Unnamed: 6': 'F8',
           'Unnamed: 7': 'P7',
           'Unnamed: 8': 'P8'}

# NOTE(review): `product`, `categories` and `blocks` are not used in this cell; they are
# kept only in case a later cell still references them.
from itertools import product
categories = [1, 2, 3, 4, 5]
blocks = [1, 2]

EEG_audio, EEG_image = dict(), dict()
for filename in tqdm(filenames):
    # Skip stray non-CSV files (e.g. OS metadata) instead of failing on them.
    if not filename.endswith('.csv'):
        continue
    # Expected name: "<participant_id>-<stimuli>.csv". Strip the suffix explicitly:
    # str.rstrip('.csv') removes any trailing '.', 'c', 's', 'v' characters, not the suffix.
    participant_id, stimuli = filename[:-len('.csv')].split('-', 1)
    data = (pandas.read_csv(f'{path}/{filename}', dtype={'Marker': str})
            .rename(columns=columns)
            .drop(columns='timestamps'))
    if stimuli == 'audio':
        EEG_audio[int(participant_id)] = data
    elif stimuli == 'image':
        EEG_image[int(participant_id)] = data
    else:
        raise ValueError(f"Stimuli:{stimuli} is unexpected.")

  0%|          | 0/60 [00:00<?, ?it/s]

In [21]:
# Preprocess data + feature extraction using wavelet 'morlet'
def _participant_epochs(pid, eeg_type, sfreq, n_expected_events):
    """Filter one participant's recording and cut it into event-locked epochs.

    Returns ``(epochs, events)``; ``events`` is the MNE event array with marker
    code 1 removed. Raises ValueError when the number of events or epochs differs
    from ``n_expected_events``.
    """
    recordings = EEG_image if eeg_type == 'image' else EEG_audio
    raw = dataframe_to_raw(recordings[pid], sfreq=sfreq)

    raw.notch_filter([50, 100], filter_length='auto', phase='zero', verbose=False)  # Line power
    raw.filter(1., None, fir_design='firwin', verbose=False)  # Slow drift

    events = mne.find_events(raw, stim_channel='Marker', initial_event=True, verbose=False)
    # Marker code 1 is not treated as a trial event.
    events = np.delete(events, np.argwhere(events[:, 2] == 1), axis=0)
    if events.shape[0] != n_expected_events:
        raise ValueError(f"Event missing: {events[:,2]}. len(events.shape[0])={events.shape[0]}")

    # NOTE(review): baseline=(0.3, 0.3) is a single-sample baseline at 0.3 s — confirm intent.
    epochs = mne.Epochs(raw, events, tmin=0, tmax=5.8, baseline=(0.3, 0.3), verbose=False)
    data_shape = epochs.get_data().shape
    if data_shape[0] != n_expected_events:
        raise ValueError(f"There might be a bad data. epochs.get_data().shape = {data_shape}")
    return epochs, events


def get_data_wt(eeg_type='image', sfreq=250, n_expected_events=50):
    """Build per-epoch Morlet wavelet power features (in dB) for every participant in ``gt``.

    Each participant's recording (``EEG_image`` or ``EEG_audio``) is notch filtered at
    50/100 Hz, high-passed at 1 Hz, epoched from 0 to 5.8 s after every event (marker
    code 1 excluded) and transformed with ``tfr_morlet`` at the frequencies below.

    Parameters
    ----------
    eeg_type : {'image', 'audio'}
        Which set of recordings to use.
    sfreq : float
        Sampling rate (Hz) passed to ``dataframe_to_raw``.
    n_expected_events : int
        Number of events/epochs every participant must have.

    Returns
    -------
    X : ndarray or None
        ``10*log10`` of the wavelet power, stacked over all epochs of all participants
        (observed shape (1500, 8, 5, 1451): epochs, channels, frequencies, samples).
        None when ``gt`` is empty.
    Y : ndarray or None
        Labels ``lh, le, lx, la, lc, lo`` of the participant, one row per epoch.
    participant_id : ndarray
        Participant id of each epoch.
    categories : ndarray of str
        First character of each epoch's event code.

    Raises
    ------
    ValueError
        On an unknown ``eeg_type`` or when a participant has the wrong number of
        events/epochs.
    """
    if eeg_type not in ('image', 'audio'):
        raise ValueError(f"eeg_type must be 'image' or 'audio', got {eeg_type!r}")

    # Morlet frequencies (Hz), one per band: 'DELTA' 'THETA' 'ALPHA' 'BETA' 'Gamma'
    # https://www.sciencedirect.com/science/article/pii/S0957417410005695
    # NOTE(review): these look like band upper edges; 125 Hz equals the Nyquist
    # frequency at sfreq=250 — confirm this is intended.
    filter_list = np.array([4, 8, 13, 30, 125])
    label_cols = ['lh', 'le', 'lx', 'la', 'lc', 'lo']

    # Collect per-participant arrays and concatenate once at the end; concatenating per
    # epoch re-copied the whole growing array every time (quadratic in the epoch count).
    X_parts, Y_parts = [], []
    participant_id, categories = [], []
    n_rows = 0
    for pid in tqdm(gt.index.tolist()):
        epochs, events = _participant_epochs(pid, eeg_type, sfreq, n_expected_events)
        powers = tfr_morlet(epochs, freqs=filter_list, n_cycles=filter_list / 2.,
                            return_itc=False, average=False, verbose=False)
        features = 10 * np.log10(powers.data)
        n_epochs = features.shape[0]

        X_parts.append(features)
        # Same label row repeated for every epoch of this participant.
        label = gt.loc[pid][label_cols].to_numpy()
        Y_parts.append(np.tile(label, (n_epochs, 1)))
        participant_id.extend([pid] * n_epochs)
        categories.extend(str(code)[0] for code in events[:, 2])

        n_rows += n_epochs
        print((n_rows,) + features.shape[1:])

    if not X_parts:
        return None, None, np.array(participant_id), np.array(categories)
    X = np.concatenate(X_parts, axis=0)
    Y = np.concatenate(Y_parts, axis=0)
    return X, Y, np.array(participant_id), np.array(categories)

In [22]:
import warnings
# NOTE(review): this silences every warning for the rest of the session (including
# library deprecation notices).
warnings.filterwarnings("ignore")

# Feature names: flat index -> "<channel>-<band>", iterating bands within each channel.
channels = list(columns.values())
band_name = np.array(['DELTA','THETA','ALPHA','BETA','Gamma'])
X_head = {
    idx: f"{channel}-{band}"
    for idx, (channel, band) in enumerate(
        (ch, bd) for ch in channels for bd in band_name
    )
}
count = len(X_head)
print(X_head)

{0: 'Fp1-DELTA', 1: 'Fp1-THETA', 2: 'Fp1-ALPHA', 3: 'Fp1-BETA', 4: 'Fp1-Gamma', 5: 'Fp2-DELTA', 6: 'Fp2-THETA', 7: 'Fp2-ALPHA', 8: 'Fp2-BETA', 9: 'Fp2-Gamma', 10: 'F3-DELTA', 11: 'F3-THETA', 12: 'F3-ALPHA', 13: 'F3-BETA', 14: 'F3-Gamma', 15: 'F4-DELTA', 16: 'F4-THETA', 17: 'F4-ALPHA', 18: 'F4-BETA', 19: 'F4-Gamma', 20: 'F7-DELTA', 21: 'F7-THETA', 22: 'F7-ALPHA', 23: 'F7-BETA', 24: 'F7-Gamma', 25: 'F8-DELTA', 26: 'F8-THETA', 27: 'F8-ALPHA', 28: 'F8-BETA', 29: 'F8-Gamma', 30: 'P7-DELTA', 31: 'P7-THETA', 32: 'P7-ALPHA', 33: 'P7-BETA', 34: 'P7-Gamma', 35: 'P8-DELTA', 36: 'P8-THETA', 37: 'P8-ALPHA', 38: 'P8-BETA', 39: 'P8-Gamma'}


In [23]:
import os
import pickle

# Create the output directory up front; open(..., 'wb') does not create missing
# directories and would raise FileNotFoundError on a fresh checkout.
out_dir = 'data_extract_raw'
os.makedirs(out_dir, exist_ok=True)

for eeg_type in ['image', 'audio']:
    X, Y, participant_id, categories = get_data_wt(eeg_type)
    data = {"Y": Y, "part_id": participant_id, "cat": categories}
    print(X.shape, Y.shape, participant_id.shape, categories.shape)
    # NOTE: pickle files execute code on load — only read them back from trusted sources.
    with open(os.path.join(out_dir, f'{eeg_type}_X.pickle'), 'wb') as handle:
        pickle.dump(X, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(out_dir, f'{eeg_type}_data.pickle'), 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

  0%|          | 0/30 [00:00<?, ?it/s]

(50, 8, 5, 1451)
(100, 8, 5, 1451)
(150, 8, 5, 1451)
(200, 8, 5, 1451)
(250, 8, 5, 1451)
(300, 8, 5, 1451)
(350, 8, 5, 1451)
(400, 8, 5, 1451)
(450, 8, 5, 1451)
(500, 8, 5, 1451)
(550, 8, 5, 1451)
(600, 8, 5, 1451)
(650, 8, 5, 1451)
(700, 8, 5, 1451)
(750, 8, 5, 1451)
(800, 8, 5, 1451)
(850, 8, 5, 1451)
(900, 8, 5, 1451)
(950, 8, 5, 1451)
(1000, 8, 5, 1451)
(1050, 8, 5, 1451)
(1100, 8, 5, 1451)
(1150, 8, 5, 1451)
(1200, 8, 5, 1451)
(1250, 8, 5, 1451)
(1300, 8, 5, 1451)
(1350, 8, 5, 1451)
(1400, 8, 5, 1451)
(1450, 8, 5, 1451)
(1500, 8, 5, 1451)
(1500, 8, 5, 1451) (1500, 6) (1500,) (1500,)


  0%|          | 0/30 [00:00<?, ?it/s]

(50, 8, 5, 1451)
(100, 8, 5, 1451)
(150, 8, 5, 1451)
(200, 8, 5, 1451)
(250, 8, 5, 1451)
(300, 8, 5, 1451)
(350, 8, 5, 1451)
(400, 8, 5, 1451)
(450, 8, 5, 1451)
(500, 8, 5, 1451)
(550, 8, 5, 1451)
(600, 8, 5, 1451)
(650, 8, 5, 1451)
(700, 8, 5, 1451)
(750, 8, 5, 1451)
(800, 8, 5, 1451)
(850, 8, 5, 1451)
(900, 8, 5, 1451)
(950, 8, 5, 1451)
(1000, 8, 5, 1451)
(1050, 8, 5, 1451)
(1100, 8, 5, 1451)
(1150, 8, 5, 1451)
(1200, 8, 5, 1451)
(1250, 8, 5, 1451)
(1300, 8, 5, 1451)
(1350, 8, 5, 1451)
(1400, 8, 5, 1451)
(1450, 8, 5, 1451)
(1500, 8, 5, 1451)
(1500, 8, 5, 1451) (1500, 6) (1500,) (1500,)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,58030,58031,58032,58033,58034,58035,58036,58037,58038,58039
0,-77.376939,-76.827312,-76.295049,-75.781008,-75.285939,-74.810491,-74.355215,-73.920577,-73.506956,-73.114660,...,-119.493397,-119.753710,-120.041524,-120.357253,-120.701013,-121.073340,-121.474437,-121.904974,-122.365095,-122.855713
1,-93.707040,-93.459220,-93.224377,-93.001355,-92.789036,-92.586364,-92.392376,-92.206233,-92.027242,-91.854884,...,-125.561423,-123.845042,-122.472105,-121.344799,-120.401526,-119.604238,-118.925068,-118.345582,-117.850855,-117.430721
2,-76.857225,-76.857637,-76.882954,-76.932623,-77.005993,-77.102299,-77.220647,-77.359983,-77.519080,-77.696500,...,-114.859311,-114.285641,-113.766125,-113.297300,-112.875326,-112.497765,-112.161876,-111.865937,-111.607858,-111.386413
3,-86.382559,-86.349949,-86.343750,-86.363837,-86.410073,-86.482310,-86.580380,-86.704094,-86.853243,-87.027586,...,-112.640844,-112.489836,-112.362132,-112.257866,-112.176761,-112.118891,-112.083972,-112.072102,-112.082986,-112.116696
4,-93.430395,-93.262254,-93.098473,-92.939260,-92.785156,-92.637008,-92.495938,-92.363294,-92.240595,-92.129483,...,-118.571094,-118.441801,-118.318086,-118.202412,-118.096235,-118.001579,-117.919724,-117.852277,-117.800276,-117.765079
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1495,-91.486104,-91.118980,-90.775145,-90.454684,-90.157690,-89.884256,-89.634474,-89.408432,-89.206206,-89.027857,...,-124.238968,-122.535521,-121.179291,-120.071085,-119.150010,-118.377138,-117.724898,-117.174292,-116.710606,-116.323049
1496,-80.837250,-80.758952,-80.707933,-80.684114,-80.687410,-80.717718,-80.774924,-80.858896,-80.969484,-81.106520,...,-109.277858,-108.878650,-108.510407,-108.172726,-107.864874,-107.586600,-107.337243,-107.116654,-106.924255,-106.759967
1497,-84.479651,-84.410646,-84.357065,-84.317655,-84.291071,-84.275883,-84.270574,-84.273544,-84.283120,-84.297569,...,-108.798060,-108.757893,-108.741912,-108.749466,-108.779730,-108.832191,-108.906188,-109.001347,-109.117052,-109.253135
1498,-82.087909,-82.127718,-82.193493,-82.284590,-82.400239,-82.539523,-82.701360,-82.884472,-83.087362,-83.308279,...,-135.739766,-137.165012,-139.068459,-141.731890,-145.925862,-155.054359,-155.654213,-145.767023,-141.206829,-138.223097
