In [111]:
import pandas as pd
import eeg
from eeg import unix_to_localdate, unix_to_period
import datetime
import scipy.fftpack
import numpy as np
import matplotlib.pyplot as plt


files = eeg.extractBundledEEG("../../scripts/data/")
files.addMeta("../../scripts/data/eeg-restingstate/events.csv","timeofday")
files.prune()
# files.mergeTagsWithRegex("[cC]lose","Eyes Closed")
files.categories
dataAgg = dict()

bands_ordered = ["delta", "theta", "alpha", "beta", "gamma"]

#TODO epoch generator

def load_session_epochs():
    pass

def load_session_summery(files: dict, qualityCutoffFilter: int = 0, epochSize: int = -1, returnEpoched = False) -> dict: # summery or epoch return
    """
    Takes a session of EEG data, computes some metrics, and returns them.
    qualityCutoffFilter: percentage of time electrode data is marked as "good" or "great" required for it to be included in output and analyticsl, 0 capture everything
    filterSampleWindow: window size to sample filters 

    Potential metrics:
     - average power by band
     - average power by channel
     - relative power by band
     - average focus/calm score
    """
    # Best channels are usually: CP3, CP4, PO3, PO4

    # NOTE: unixTimestamps are int/seconds but samples more often than 1Hz,
    #       so several rows per timestamp and missing sub-second resolution.
    df_pbb = pd.read_json(files["powerByBand"])
    df_sigQ =  pd.read_json(files["signalQuality"])
    df_pbb.set_index("unixTimestamp", inplace=True)
    df_sigQ.set_index("unixTimestamp", inplace=True)
    df_pbb.index = pd.to_datetime(df_pbb.index, unit="s")
    df_sigQ.index = pd.to_datetime(df_sigQ.index, unit="s")


    if epochSize == -1: epochSize = (df_pbb.index[df_pbb.shape[0]-1]-df_pbb.index[0]).seconds

    # we have to deal with the channels, such as CP3_alpha, CP3_beta, etc.
    # for now, we will just average them all together
    channels, bands = zip(*[c.split("_") for c in df_pbb.columns])
    channels, bands = list(set(channels)), list(set(bands))

    temporalFilter = {x : pd.Series([True] * df_pbb.shape[0],index=df_pbb.index) for x in channels}

    removedChannels = []
    epochSize = datetime.timedelta(seconds=epochSize)

    newTime = df_pbb.index[0] + epochSize
    oldTime = df_pbb.index[0]

    endTime = df_pbb.index[df_pbb.shape[0]-1]

    if(newTime>endTime): newTime = endTime
    finalLoop = False

    while(endTime>=newTime):
        sigSamp = df_sigQ[(df_sigQ.index < newTime) & (df_sigQ.index >= oldTime)]
        pbbSamp = (df_pbb.index < newTime) & (df_pbb.index >= oldTime)
        if sigSamp.empty and finalLoop:
            if percentage_good<qualityCutoffFilter: temporalFilter[channel][pbbSamp] = False
            break

        for channel in channels:
            col = channel + "_status"
            # epochDf 
            channel_states = sigSamp[col].value_counts()
            no_of_okay_samples = 0
            if 'good' in channel_states:
                no_of_okay_samples += channel_states['good']
            if 'great' in channel_states:
                no_of_okay_samples += channel_states['great']

            percentage_good = no_of_okay_samples / sigSamp.shape[0]
            if percentage_good<qualityCutoffFilter: temporalFilter[channel][pbbSamp] = False

        if((newTime+epochSize)>endTime): 
            oldTime = newTime
            newTime = endTime
            finalLoop = True
        else:
            oldTime = newTime
            newTime = newTime + epochSize
    

    
    for x in channels: 
        # print(temporalFilter[x].value_counts())
        if True not in temporalFilter[x].value_counts():
            removedChannels.append(x)
            channels.remove(x)
        elif(temporalFilter[x].value_counts()[True]<10):
            removedChannels.append(x)
            channels.remove(x)

    if removedChannels: print("Channels",*removedChannels,"removed")

    df = pd.DataFrame(index=df_pbb.index)
    for channel in channels:
        bands_for_channel = [c for c in df_pbb.columns if c.startswith(channel)]
        df_pbb.loc[temporalFilter[channel],bands_for_channel] = pd.NA # remove all pruned entries
        df[channel] = df_pbb[bands_for_channel].mean(axis=1)

    average_channel_power = df.mean()[channels]
    print(average_channel_power)

    df = pd.DataFrame(index=df_pbb.index)
    for band in bands:
        channels_with_band = [c for c in df_pbb.columns if c.endswith(band) and c.split("_")[0] not in removedChannels]
        df[band] = df_pbb[channels_with_band].mean(axis=1)
    average_band_power = df.mean()[bands_ordered]



    # TODO: split into low(0.3-0.6), medium(0.6-0.7), high(0.7-1.0)
    df_calm = pd.read_json(files["calm"])
    avg_calm_score = df_calm["probability"].mean()
    time_spent_calm = (df_calm["probability"] > 0.3).sum() / len(df_calm)

    df_focus = pd.read_json(files["focus"])
    avg_focus_score = df_focus["probability"].mean()
    time_spent_focused = (df_focus["probability"] > 0.3).sum() / len(df_focus)

    unix_timestamp = int(df_pbb.index[0].timestamp())
    if len(channels)==0: print(df_pbb.index[0].date(),df_pbb.index[0].time(),"WARNING all channels removed")

    
    return {
        "timestamp": unix_timestamp,
        "local_date": unix_to_localdate(unix_timestamp),
        "local_timeofday": unix_to_period(unix_timestamp),
        "duration": df_pbb.index[-1] - df_pbb.index[0],
        "avg_power_per_channel_by_band": {
            channel: {band: df_pbb[channel + "_" + band].mean() for band in bands}
            for channel in channels
        },
        "avg_power_by_band": dict(average_band_power),
        "avg_power_by_channel": dict(average_channel_power),
        "avg_calm_score": avg_calm_score,
        "avg_focus_score": avg_focus_score,
        "time_spent_calm": time_spent_calm,
        "time_spent_focused": time_spent_focused,
        # `relative_power` keys are 2-tuples (band1, band2), values are ratios
        # maybe doesn't need to be computed here,
        # can be computed later from `avg_power_by_band`
        # "relative_power": {},
        #"signal_quality" based on the signal quality data
    }

for x in ["morning","evening"]:
    for y in files.extractByTags(x):
        if x not in dataAgg: dataAgg[x] = {}
        dataStage = load_session_summery(files.extractById(y),.8)["avg_power_per_channel_by_band"]
        print(dataStage)
        if list(dataStage.keys())!=[]: dataAgg[x][y] = dataStage

1674926196 Meta available but no associated recordings
Channels F5 removed
C4           NaN
F6     12.741566
CP4          NaN
CP3          NaN
PO4          NaN
PO3          NaN
C3           NaN
dtype: float64
{'C4': {'delta': nan, 'beta': nan, 'alpha': nan, 'theta': nan, 'gamma': nan}, 'F6': {'delta': 25.94566398481464, 'beta': 3.8333148372235404, 'alpha': 7.9389064783942525, 'theta': 25.080879744974574, 'gamma': 0.9090644026875471}, 'CP4': {'delta': nan, 'beta': nan, 'alpha': nan, 'theta': nan, 'gamma': nan}, 'CP3': {'delta': nan, 'beta': nan, 'alpha': nan, 'theta': nan, 'gamma': nan}, 'PO4': {'delta': nan, 'beta': nan, 'alpha': nan, 'theta': nan, 'gamma': nan}, 'PO3': {'delta': nan, 'beta': nan, 'alpha': nan, 'theta': nan, 'gamma': nan}, 'C3': {'delta': nan, 'beta': nan, 'alpha': nan, 'theta': nan, 'gamma': nan}}
Channels C4 F6 CP3 PO3 removed
F5     11.517261
CP4    10.775826
PO4     9.455782
C3     13.341053
dtype: float64
{'F5': {'delta': 21.653144625111526, 'beta': 5.095371123229

In [103]:
dataAgg = {k : { x : [{y+"_"+z : dataAgg[k][x][y][z] for z in dataAgg[k][x][y]} for y in dataAgg[k][x]] for x in dataAgg[k]} for k in dataAgg}
for x in dataAgg:
    for y in dataAgg[x]:
        accumulated = {}
        for z in dataAgg[x][y]:
            accumulated.update(z)
        dataAgg[x][y] = accumulated

Simple Aggregated Trials

In [104]:
morning = pd.DataFrame.from_dict(dataAgg["morning"])
print("std",morning.std(axis=1, numeric_only=True))
print("mean",morning.mean(axis=1, numeric_only=True))
# print(morning)

std C4_delta    NaN
C4_beta     NaN
C4_alpha    NaN
C4_theta    NaN
C4_gamma    NaN
F5_delta    NaN
F5_beta     NaN
F5_alpha    NaN
F5_theta    NaN
F5_gamma    NaN
F6_delta    NaN
F6_beta     NaN
F6_alpha    NaN
F6_theta    NaN
F6_gamma    NaN
CP4_delta   NaN
CP4_beta    NaN
CP4_alpha   NaN
CP4_theta   NaN
CP4_gamma   NaN
CP3_delta   NaN
CP3_beta    NaN
CP3_alpha   NaN
CP3_theta   NaN
CP3_gamma   NaN
PO4_delta   NaN
PO4_beta    NaN
PO4_alpha   NaN
PO4_theta   NaN
PO4_gamma   NaN
PO3_delta   NaN
PO3_beta    NaN
PO3_alpha   NaN
PO3_theta   NaN
PO3_gamma   NaN
C3_delta    NaN
C3_beta     NaN
C3_alpha    NaN
C3_theta    NaN
C3_gamma    NaN
dtype: float64
mean C4_delta    NaN
C4_beta     NaN
C4_alpha    NaN
C4_theta    NaN
C4_gamma    NaN
F5_delta    NaN
F5_beta     NaN
F5_alpha    NaN
F5_theta    NaN
F5_gamma    NaN
F6_delta    NaN
F6_beta     NaN
F6_alpha    NaN
F6_theta    NaN
F6_gamma    NaN
CP4_delta   NaN
CP4_beta    NaN
CP4_alpha   NaN
CP4_theta   NaN
CP4_gamma   NaN
CP3_delta   NaN


In [105]:
evening = pd.DataFrame.from_dict(dataAgg["evening"])
print("std",evening.std(axis=1, numeric_only=True))
print("mean",evening.mean(axis=1, numeric_only=True))

std C4_delta    NaN
C4_beta     NaN
C4_alpha    NaN
C4_theta    NaN
C4_gamma    NaN
F5_delta    NaN
F5_beta     NaN
F5_alpha    NaN
F5_theta    NaN
F5_gamma    NaN
F6_delta    NaN
F6_beta     NaN
F6_alpha    NaN
F6_theta    NaN
F6_gamma    NaN
CP4_delta   NaN
CP4_beta    NaN
CP4_alpha   NaN
CP4_theta   NaN
CP4_gamma   NaN
CP3_delta   NaN
CP3_beta    NaN
CP3_alpha   NaN
CP3_theta   NaN
CP3_gamma   NaN
PO4_delta   NaN
PO4_beta    NaN
PO4_alpha   NaN
PO4_theta   NaN
PO4_gamma   NaN
PO3_delta   NaN
PO3_beta    NaN
PO3_alpha   NaN
PO3_theta   NaN
PO3_gamma   NaN
C3_delta    NaN
C3_beta     NaN
C3_alpha    NaN
C3_theta    NaN
C3_gamma    NaN
dtype: float64
mean C4_delta    NaN
C4_beta     NaN
C4_alpha    NaN
C4_theta    NaN
C4_gamma    NaN
F5_delta    NaN
F5_beta     NaN
F5_alpha    NaN
F5_theta    NaN
F5_gamma    NaN
F6_delta    NaN
F6_beta     NaN
F6_alpha    NaN
F6_theta    NaN
F6_gamma    NaN
CP4_delta   NaN
CP4_beta    NaN
CP4_alpha   NaN
CP4_theta   NaN
CP4_gamma   NaN
CP3_delta   NaN
