In [325]:
import pandas as pd
import eeg
from eeg import unix_to_localdate, unix_to_period
import datetime
import scipy.fftpack
import numpy as np
import matplotlib.pyplot as plt


files = eeg.extractBundledEEG("../../scripts/data/")
files.addMeta("../../scripts/data/eeg-restingstate/events.csv","timeofday")
files.prune()
# files.mergeTagsWithRegex("[cC]lose","Eyes Closed")
files.categories
dataAgg = dict()

bands_ordered = ["delta", "theta", "alpha", "beta", "gamma"]

sanityTruncation  = {"alpha": 9.301913, #neurosity quality value can lag signal, allowing obscene outliers
"beta":       3.997140,
"delta":     33.608641,
"gamma":      0.992294,
"theta":     21.004026,
}

#TODO epoch generator

def load_session_epochs():
    pass

def load_session_summery(files: dict, qualityCutoffFilter: int = 0, epochSize: int = -1, returnEpoched = False) -> dict: # summery or epoch return
    """
    Takes a session of EEG data, computes some metrics, and returns them.
    qualityCutoffFilter: percentage of time electrode data is marked as "good" or "great" required for it to be included in output and analyticsl, 0 capture everything
    filterSampleWindow: window size to sample filters 

    Potential metrics:
     - average power by band
     - average power by channel
     - relative power by band
     - average focus/calm score
    """
    # Best channels are usually: CP3, CP4, PO3, PO4

    # NOTE: unixTimestamps are int/seconds but samples more often than 1Hz,
    #       so several rows per timestamp and missing sub-second resolution.
    df_pbb = pd.read_json(files["powerByBand"])
    df_sigQ =  pd.read_json(files["signalQuality"])
    df_pbb.set_index("unixTimestamp", inplace=True)
    df_sigQ.set_index("unixTimestamp", inplace=True,drop=False)
    df_pbb.index = pd.to_datetime(df_pbb.index, unit="s")
    df_sigQ.index = pd.to_datetime(df_sigQ.index, unit="s")




    if epochSize == -1: epochSize = (df_pbb.index[df_pbb.shape[0]-1]-df_pbb.index[0]).seconds

    # we have to deal with the channels, such as CP3_alpha, CP3_beta, etc.
    # for now, we will just average them all together
    channels, bands = zip(*[c.split("_") for c in df_pbb.columns])
    channels, bands = list(set(channels)), list(set(bands))

    temporalFilter = {x : pd.Series([True] * df_pbb.shape[0],index=df_pbb.index) for x in channels}
    for x in channels: df_pbb["epoch_"+x] = pd.NA

    removedChannels = []
    epochSize = datetime.timedelta(seconds=epochSize)

    oldTime = df_pbb.index[0]
    newTime = df_pbb.index[0] + epochSize

    endTime = df_pbb.index[df_pbb.shape[0]-1]

    if(newTime>endTime): newTime = endTime
    finalLoop = False
    # print(oldTime,newTime, df_pbb.shape[0],df_sigQ.shape[0])
    qualitySamplesMissing = 0
    validEpochsCount = { x : 0 for x in channels}


    while(endTime>=newTime):
        sigSamp = df_sigQ[(df_sigQ.index < newTime) & (df_sigQ.index >= oldTime)]
        # print(oldTime,newTime, df_sigQ[df_sigQ.index >= oldTime][["PO3_status","unixTimestamp"]].head(10))
        pbbSamp = (df_pbb.index < newTime) & (df_pbb.index >= oldTime)
        if finalLoop:  pbbSamp = (df_pbb.index >= oldTime)
        if sigSamp.empty: # if no quality data in epoch remove it
            # print("NO QUALIFYING COMPONET\n")
            for channel in channels: temporalFilter[channel][pbbSamp] = False
            if True in pbbSamp: qualitySamplesMissing+=1
        else:
            for channel in channels:
                col = channel + "_status"
                channel_states = sigSamp[col].value_counts()
                no_of_okay_samples = 0
                if 'good' in channel_states:
                    no_of_okay_samples += channel_states['good']
                if 'great' in channel_states:
                    no_of_okay_samples += channel_states['great']

                percentage_good = no_of_okay_samples / sigSamp.shape[0]
                if percentage_good<qualityCutoffFilter: temporalFilter[channel][pbbSamp] = False
                else: 
                    # if channel == "PO3": 
                    #     print(oldTime,newTime," ",validEpochsCount[channel])
                    #     print(pbbSamp[0:10])
                    df_pbb.loc[pbbSamp,"epoch_"+channel] = int(validEpochsCount[channel])
                    validEpochsCount[channel]+=1

        if((newTime+epochSize)>endTime and not finalLoop): 
            oldTime = newTime
            newTime = endTime
            finalLoop = True
        else:
            oldTime = newTime
            newTime = newTime + epochSize
    

    # print(df_pbb["epoch_PO3"].head(20))
    if qualitySamplesMissing >  ((endTime - df_pbb.index[0])/epochSize)*.2:
        print("Error, substantial amount of quality samples misaligned or missing")

    for x in channels: 
        # print(temporalFilter[x].value_counts())
        if True not in temporalFilter[x].value_counts():
            removedChannels.append(x)
            channels.remove(x)
        else:
            print(x,1-(temporalFilter[x].value_counts()[True]/temporalFilter[x].shape[0]),"lost")
            if (temporalFilter[x].value_counts()[True]<10):
                removedChannels.append(x)
                channels.remove(x)

    if removedChannels: print("Channels",*removedChannels,"removed")

    df = pd.DataFrame(index=df_pbb.index)
    for channel in channels:
        bands_for_channel = [c for c in df_pbb.columns if c.startswith(channel)]
        df_pbb.loc[~temporalFilter[channel],bands_for_channel] = pd.NA # remove all pruned entries
        df[channel] = df_pbb[bands_for_channel].mean(axis=1)

    average_channel_power = df.mean()[channels]
    # print(average_channel_power)

    df = pd.DataFrame(index=df_pbb.index)
    for band in bands:
        channels_with_band = [c for c in df_pbb.columns if c.endswith(band) and c.split("_")[0] not in removedChannels]
        df[band] = df_pbb[channels_with_band].mean(axis=1)
    average_band_power = df.mean()[bands_ordered]

    # TODO: split into low(0.3-0.6), medium(0.6-0.7), high(0.7-1.0)
    df_calm = pd.read_json(files["calm"])
    avg_calm_score = df_calm["probability"].mean()
    time_spent_calm = (df_calm["probability"] > 0.3).sum() / len(df_calm)

    df_focus = pd.read_json(files["focus"])
    avg_focus_score = df_focus["probability"].mean()
    time_spent_focused = (df_focus["probability"] > 0.3).sum() / len(df_focus)

    unix_timestamp = int(df_pbb.index[0].timestamp())
    if len(channels)==0: print(df_pbb.index[0].date(),df_pbb.index[0].time(),"WARNING all channels removed")

    
    return {
        "timestamp": unix_timestamp,
        "local_date": unix_to_localdate(unix_timestamp),
        "local_timeofday": unix_to_period(unix_timestamp),
        "duration": df_pbb.index[-1] - df_pbb.index[0],
        "avg_power_per_channel_by_band": {
            channel: {band: df_pbb[channel + "_" + band].mean() for band in bands}
            for channel in channels
        },
        "avg_power_by_band": dict(average_band_power),
        "avg_power_by_channel": dict(average_channel_power),
        "avg_calm_score": avg_calm_score,
        "avg_focus_score": avg_focus_score,
        "time_spent_calm": time_spent_calm,
        "time_spent_focused": time_spent_focused,
        # `relative_power` keys are 2-tuples (band1, band2), values are ratios
        # maybe doesn't need to be computed here,
        # can be computed later from `avg_power_by_band`
        # "relative_power": {},
        #"signal_quality" based on the signal quality data
    }


debug = 0
for x in ["morning","evening"]:
    for y in files.extractByTags(x):
        if x not in dataAgg: dataAgg[x] = {}
        debug+=1
        if debug == 12 or True:
            dataStage = load_session_summery(files.extractById(y),.8,epochSize=30)["avg_power_per_channel_by_band"]
            # print(dataStage)
            # print(y)
            # file = files.extractById(y)
            # df_pbb = pd.read_json(file["powerByBand"])
            # df_sigQ =  pd.read_json(file["signalQuality"])
            # df_pbb.set_index("unixTimestamp", inplace=True)
            # df_sigQ.set_index("unixTimestamp", inplace=True,drop=False)
            # df_pbb.index = pd.to_datetime(df_pbb.index, unit="s")
            # df_sigQ.index = pd.to_datetime(df_sigQ.index, unit="s")

            # df_pbb["Quality"] = pd.NA

            # common = set(df_sigQ.index).intersection(set(df_pbb.index))
            # df_pbb.loc[list(common),"Quality"] = df_sigQ.drop_duplicates(subset=["unixTimestamp"]).loc[list(common),"PO4_status"]
            # df_pbb[["Quality","CP3_alpha","C3_theta"]].head(40)
            # print(df_pbb[(df_pbb["Quality"]=="good") | (df_pbb["Quality"]=="great")][[x for x in df_pbb.columns if "PO4" in x ]].max(numeric_only=True))

            if list(dataStage.keys())!=[]: dataAgg[x][y] = dataStage

1674926196 Meta available but no associated recordings
C4 0.0016666666666667052 lost
F5 0.4504166666666667 lost
CP3 0.0016666666666667052 lost
PO4 0.0016666666666667052 lost
PO3 0.0016666666666667052 lost
C3 0.0016666666666667052 lost
Channels F6 removed
C4 0.8814814814814815 lost
F5 0.6333333333333333 lost
F6 0.6333333333333333 lost
CP4 0.6333333333333333 lost
CP3 0.6333333333333333 lost
PO4 0.6333333333333333 lost
PO3 0.6333333333333333 lost
C3 0.8814814814814815 lost
C4 0.18470149253731338 lost
F5 0.3843283582089553 lost
F6 0.220771144278607 lost
CP4 0.07400497512437809 lost
CP3 0.18470149253731338 lost
PO4 0.0 lost
PO3 0.0 lost
C3 0.18470149253731338 lost
C4 0.05087572977481236 lost
F5 0.14970809007506258 lost
F6 0.24979149291075897 lost
CP4 0.0008340283569641116 lost
CP3 0.5500417014178482 lost
PO4 0.09966638865721433 lost
PO3 0.05087572977481236 lost
C3 0.05087572977481236 lost
C4 0.5 lost
F5 0.0 lost
F6 0.0 lost
CP4 0.0 lost
CP3 0.0 lost
PO4 0.0 lost
PO3 0.20016680567139278 lost

In [322]:
id =   1674719194
# id = 1674719194
# id = 1685033809
file = files.extractById(id)
df_pbb = pd.read_json(file["powerByBand"])
df_sigQ =  pd.read_json(file["signalQuality"])
df_pbb.set_index("unixTimestamp", inplace=True)
df_sigQ.set_index("unixTimestamp", inplace=True,drop=False)
df_pbb.index = pd.to_datetime(df_pbb.index, unit="s")
df_sigQ.index = pd.to_datetime(df_sigQ.index, unit="s")

df_pbb["Quality"] = pd.NA

common = set(df_sigQ.index).intersection(set(df_pbb.index))
df_pbb.loc[list(common),"Quality"] = df_sigQ.drop_duplicates(subset=["unixTimestamp"]).loc[list(common),"PO4_status"]

df_pbb[["Quality","CP3_alpha","C3_theta"]].head(40)

# df_pbb[(df_pbb["Quality"]=="good") | (df_pbb["Quality"]=="great")][[x for x in df_pbb.columns if "PO4" in x ]].max(numeric_only=True)


Unnamed: 0_level_0,Quality,CP3_alpha,C3_theta
unixTimestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2023-01-26 07:46:35,great,2.941192,5.532516
2023-01-26 07:46:35,great,2.488266,5.877565
2023-01-26 07:46:35,great,2.043517,4.620965
2023-01-26 07:46:36,great,0.729971,2.434084
2023-01-26 07:46:36,great,0.980746,0.910376
2023-01-26 07:46:36,great,0.722352,1.586728
2023-01-26 07:46:36,great,2.296571,5.65726
2023-01-26 07:46:37,bad,2.857461,14.848285
2023-01-26 07:46:37,bad,6.502483,3.641605
2023-01-26 07:46:37,bad,12.423698,16.205206


In [330]:
dataAgg = {k : { x : [{y+"_"+z : dataAgg[k][x][y][z] for z in dataAgg[k][x][y]} for y in dataAgg[k][x]] for x in dataAgg[k]} for k in dataAgg}
for x in dataAgg:
    for y in dataAgg[x]:
        accumulated = {}
        for z in dataAgg[x][y]:
            accumulated.update(z)
        dataAgg[x][y] = accumulated

Simple Aggregated Trials

In [350]:
morning = pd.DataFrame.from_dict(dataAgg["morning"],orient="index")
morningAgg = pd.DataFrame()
chan, band = zip(*[x.split("_") for x in morning.columns])
chan = set(chan)
band = set(band)
for x in band:
    morningAgg[x] = morning[[col for col in morning.columns if x in col]].mean(axis=1)

# print("std",morning.std(axis=1, numeric_only=True))
# print("mean",morning.mean(axis=1, numeric_only=True))

print("std",morningAgg.std(axis=0, numeric_only=True))
print("mean",morningAgg.mean(axis=0, numeric_only=True))
# morningAgg

std delta    1.294381
beta     0.402253
alpha    0.586416
theta    1.045804
gamma    0.076654
dtype: float64
mean delta    4.294237
beta     1.253578
alpha    2.463736
theta    3.757177
gamma    0.231219
dtype: float64


In [352]:
evening = pd.DataFrame.from_dict(dataAgg["evening"],orient="index")

eveningAgg = pd.DataFrame()
chan, band = zip(*[x.split("_") for x in evening.columns])
chan = set(chan)
band = set(band)

for x in band:
    eveningAgg[x] = evening[[col for col in evening.columns if x in col]].mean(axis=1)

# print("std",evening.std(axis=0, numeric_only=True))
#print("mean",evening.mean(axis=0, numeric_only=True))
print("std",eveningAgg.std(axis=0, numeric_only=True))
print("mean",eveningAgg.mean(axis=0, numeric_only=True))
# evening

std delta    1.487258
beta     0.266771
alpha    0.505952
theta    1.077685
gamma    0.042930
dtype: float64
mean delta    4.749583
beta     1.312255
alpha    2.526306
theta    4.004862
gamma    0.244501
dtype: float64


In [324]:
evening = pd.DataFrame.from_dict(dataAgg["evening"])
# print("std",evening.std(axis=1, numeric_only=True))
# print("mean",evening.mean(axis=1, numeric_only=True))
evening

Unnamed: 0,1674454790,1674539317,1674613052,1674630344,1674719194,1674791018,1674886663,1674959009,1674966879,1675044968,1675054666,1675142192
C4_delta,2.446844,2.76915,4.447463,4.131151,3.286249,1.838343,3.94955,4.496858,3.768461,4.11146,3.525678,11.846066
C4_beta,0.999524,1.090198,1.388719,1.006178,1.332103,0.769659,1.08615,1.263584,1.345298,1.273389,1.308261,2.109239
C4_alpha,1.869801,1.811696,2.719571,1.981798,2.323795,1.401233,2.223268,2.511642,2.26048,2.207536,2.247998,4.399217
C4_theta,2.339327,2.525544,4.097553,3.481412,3.292033,1.728348,3.530035,4.018669,3.472172,3.565696,3.230459,9.087194
C4_gamma,0.188591,0.22668,0.253379,0.205236,0.286666,0.14907,0.202204,0.227321,0.24876,0.244705,0.260948,0.330448
F5_delta,3.393191,5.661824,2.882594,,3.328692,1.733902,3.765745,2.219505,2.812855,2.007267,10.323107,5.961563
F5_beta,1.171619,1.154628,0.992455,,1.228439,0.720974,0.784362,0.736382,0.974304,1.009121,1.866341,1.255727
F5_alpha,1.876202,2.296687,1.743781,,2.051093,0.996176,1.565027,1.368441,1.515929,1.367208,3.832078,2.431269
F5_theta,2.989307,4.479714,2.627614,,3.037203,1.496166,3.20454,2.100016,2.543277,1.867198,7.959754,4.715405
F5_gamma,0.228229,0.191163,0.196343,,0.270723,0.154463,0.164837,0.139846,0.198592,0.212098,0.296204,0.215956
