In [1]:
import numpy as np
import scipy
from scipy.io import loadmat
import mne, glob
from mne_features.feature_extraction import extract_features
import pandas as pd
from ipynb.fs.full.fullDataExtraction import getRawArrayData
from ipynb.fs.full.fullDataExtraction import extarctChannelNames
import plotly.express as px
import plotly.offline as py
import plotly.graph_objects as go
import plotly
import math
from tensorflow import keras
import dash_bootstrap_components as dbc
from ipynb.fs.full.fullDataExtraction import compute_diffEnt
from dash import Dash, dcc, html, Input, Output, State, ctx
from jupyter_dash import JupyterDash
from dash.exceptions import PreventUpdate
import dash

In [2]:
def makeEpochs(rawArrays):
    """Cut every raw clip into consecutive, non-overlapping 1-second epochs.

    Parameters
    ----------
    rawArrays : dict
        Clip key -> ``mne.io.RawArray`` (as produced by ``getRawArrays``).

    Returns
    -------
    dict
        Same keys, each mapped to the preloaded ``mne.Epochs`` built by
        ``mne.make_fixed_length_epochs(..., duration=1)``.
    """
    return {
        clipKey: mne.make_fixed_length_epochs(
            clipRaw, duration=1, preload=True, verbose=0
        )
        for clipKey, clipRaw in rawArrays.items()
    }

def getRawArrays(matfile, numClips=15):
    """Wrap each EEG clip of a loaded ``.mat`` recording in an ``mne.io.RawArray``.

    Parameters
    ----------
    matfile : dict
        Result of ``scipy.io.loadmat``. The dict is NOT modified: loadmat's
        ``__header__``/``__version__``/``__globals__`` entries are skipped
        instead of deleted, so the same dict can safely be passed in again.
    numClips : int, default 15
        Number of clip variables to read, named ``<prefix>1`` .. ``<prefix><numClips>``.

    Returns
    -------
    dict
        Maps the 0-based clip position (0 .. numClips-1) to its ``RawArray``.

    Raises
    ------
    KeyError
        If one of the expected ``<prefix><n>`` variables is missing.
    """
    # Metadata entries added by loadmat start with "__"; everything else is data.
    dataKeys = [key for key in matfile if not key.startswith("__")]
    # Recover the shared variable prefix by stripping the clip number from the
    # first data key (e.g. "djc_eeg1" -> "djc_eeg"). Stripping all trailing
    # digits, rather than just one character, also copes with "...10" etc.
    keyName = dataKeys[0].rstrip("0123456789")
    channelNamesExtraction=extarctChannelNames(pd.read_excel("Preprocessed_EEG/channel-order.xlsx"))
    # Every clip shares the same channel list, sampled at 200 Hz.
    info=mne.create_info(channelNamesExtraction,200,'eeg')
    return {
        ind: mne.io.RawArray(matfile[keyName + str(ind + 1)], info, verbose=0)
        for ind in range(numClips)
    }

# Parsed topographic templates, keyed by path, so the spreadsheet is read once
# instead of once per heat-map frame.
_TOPO_TEMPLATE_CACHE = {}


def _loadTopoTemplate(path="Preprocessed_EEG/channel-order(topo).xlsx"):
    """Return the cached float32 topographic grid read from ``path``."""
    if path not in _TOPO_TEMPLATE_CACHE:
        _TOPO_TEMPLATE_CACHE[path] = pd.read_excel(path).to_numpy().astype(np.float32)
    return _TOPO_TEMPLATE_CACHE[path]


def makeValueMatrix(currentAnalysed, template=None):
    """Lay one reading per channel onto the 2-D topographic grid.

    Every non-NaN cell of the template is a channel slot. The slots are filled
    in row-major order with successive entries of ``currentAnalysed``, and NaN
    cells stay NaN.

    Parameters
    ----------
    currentAnalysed : sequence of float
        One value per slot, in row-major slot order; extra values are ignored.
    template : array-like, optional
        Grid to fill. Defaults to the (cached) ``channel-order(topo).xlsx``
        sheet. It is copied and never modified.

    Returns
    -------
    numpy.ndarray
        float32 grid with the same shape as the template.

    Raises
    ------
    IndexError
        If ``currentAnalysed`` has fewer values than the template has slots.
    """
    if template is None:
        template = _loadTopoTemplate()
    # np.array copies, so the cached/caller template is never overwritten.
    result = np.array(template, dtype=np.float32)
    slots = ~np.isnan(result)
    numSlots = int(slots.sum())
    values = np.asarray(currentAnalysed)
    if values.shape[0] < numSlots:
        raise IndexError(
            f"need {numSlots} values for the topographic grid, got {values.shape[0]}"
        )
    # Boolean-mask assignment visits cells in C (row-major) order, exactly like
    # the former nested row/column loop.
    result[slots] = values[:numSlots]
    return result

def makePostionList(data, numColumns=500):
    """Build one topographic value grid per time sample.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array whose column ``c`` holds every channel's reading at sample
        ``c``. In this notebook, rows are channels, as required by
        ``mne.EvokedArray(data, info)``.
    numColumns : int, default 500
        Number of leading samples (columns) to convert. This was previously
        hard-coded, and it must not exceed ``data.shape[1]``.

    Returns
    -------
    list of numpy.ndarray
        ``makeValueMatrix(data[:, c])`` for ``c`` in ``0 .. numColumns-1``.
    """
    return [makeValueMatrix(data[:, column]) for column in range(numColumns)]

def ExtractFeatures(sliced_epochs,features, sfreq=200):
    """Compute mne-features for every epoch of every clip and stack the rows.

    Parameters
    ----------
    sliced_epochs : dict
        Clip key -> ``mne.Epochs`` (as returned by ``makeEpochs``).
    features : list
        ``selected_funcs`` passed to ``mne_features.extract_features``.
    sfreq : float, default 200
        Sampling frequency passed to ``extract_features``. This was previously
        hard-coded.

    Returns
    -------
    numpy.ndarray
        One row per epoch, with clips in the iteration order of ``sliced_epochs``.
        An empty input gives shape ``(0, len(features) * 62)``, as before.
    """
    perClip = [
        extract_features(sliced_epochs[cut].get_data(), sfreq, features)
        for cut in sliced_epochs
    ]
    if not perClip:
        return np.zeros((0, len(features) * 62))
    # One vstack at the end instead of re-stacking onto a placeholder zero row
    # inside the loop, which copied the whole accumulated array on every clip.
    return np.vstack(perClip).astype(float, copy=False)

# --- Channel names and montage ---------------------------------------------
channels=pd.read_excel("Preprocessed_EEG/channel-order(viz).xlsx")
channelNames=channels.iloc[:,0].tolist()
# "Fp1" is prepended by hand.
# NOTE(review): presumably read_excel consumed the sheet's first row ("Fp1") as
# the column header — confirm; if so, read_excel(..., header=None) avoids this.
channelNames.insert(0,"Fp1")
ch_types = ['eeg'] * len(channelNames)

info = mne.create_info(channelNames, ch_types=ch_types, sfreq=200)
info.set_montage('standard_1020')

# --- One recording, clip 1 ---------------------------------------------------
subject=loadmat('1_20131107.mat')
# Drop channel rows 57 and 61 so the array matches channelNames / the montage.
# NOTE(review): presumably these are CB1/CB2, which standard_1020 lacks — confirm.
data=np.delete(subject['djc_eeg1'], [57, 61], axis=0)
evoked=mne.EvokedArray(data,info)
transposedData=np.transpose(data)
# First 500 samples, one row per sample and one column per channel; these rows
# feed the bar charts.
df = pd.DataFrame(transposedData, columns = channelNames).head(500)

# --- Emotion predictions ------------------------------------------------------
model = keras.models.load_model("emotionPredictionModel")
rawArrays=getRawArrays(subject)
slicedEpochs=makeEpochs(rawArrays)
extractedData = ExtractFeatures(slicedEpochs,['hjorth_mobility', ('diffEnt',compute_diffEnt), 'rms', 'skewness'])
points=makePostionList(data)

y_prob = model.predict(extractedData)
# One-hot encode the most probable class so each epoch gets exactly one label.
# Thresholding every probability at 0.5 can leave a row with no class set
# (all zeros) when the model is unsure.
y_pred = np.eye(y_prob.shape[1], dtype=int)[np.argmax(y_prob, axis=1)].tolist()
# NOTE(review): y_pred has one row per 1-second epoch across all clips, while
# df/points have one row per 200 Hz sample of clip 1, so the same time index
# refers to different moments in the emotion panel and the charts — confirm.

# --- Channel groups (label, channels) for dropdowns and bar charts ------------
# 10-20 convention: odd-numbered electrodes lie over the left hemisphere and
# even-numbered ones over the right; the temporal groups follow that.
# NOTE(review): T7/T8 are not in any group — confirm whether that is intended.
frontal=("Frontal",['Fp1', 'Fpz', 'Fp2','AF3', 'AF4','F1', 'Fz', 'F2'])
central=("Central",['FC1', 'FCz', 'FC2','C3', 'C1', 'Cz', 'C2', 'C4'])
parietal=("Parietal",['CP3', 'CP1', 'CPz', 'CP2', 'CP4','P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6'])
occipital=("Occipital",['PO7', 'PO5', 'PO3','POz', 'PO4', 'PO6', 'PO8','O1', 'Oz', 'O2'])
left=("Left Temporal",['F7', 'F5', 'F3','FT7', 'FC5', 'FC3','C5','TP7', 'CP5','P7'])
right=("Right Temporal",['F4', 'F6', 'F8','FC4','FC6', 'FT8','C6','CP6', 'TP8','P8'])




## Main component definitions

### Time slider and numeric input

In [204]:
# Time-frame slider (id "time"). Its 0-499 range matches the 500 rows kept in
# `df` and the 500 heat-map grids in `points`.
slider = html.Div(
    [
        dcc.Slider(
            min=0,
            max=499,
            step=1,
            id="time",
            marks=None,
            tooltip={"always_visible": False},
        ),
    ],
    style={'padding-top':'2rem'}
)

# NOTE(review): sliderLabel and sliderColumn are not used by the layout cells
# shown here — confirm before removing.
sliderLabel=dbc.Row(html.Label("Select time frame"),style ={"text-align": "center"})

sliderColumn= dbc.Col([slider],width=2,style = {"height":"30%","text-align": "center",'padding-left':'50%', 'padding-right':'50%'})

inputLabel=html.Div([
    html.H3("Input the time frame you'd like to view"),
    "(you may also use the slider to select your time frame)"
],
    style={'padding-top':'3rem'}
)

# Numeric time entry (id "inputTime"). A callback keeps it in sync with the slider.
# step=1 mirrors the slider, so the browser treats fractional entries as
# invalid: the value is later used as a row index (df.iloc / points[...]).
inputCol=dbc.Col(
    dcc.Input(
        id="inputTime",
        type="number",
        placeholder="Select time",
        min=0,
        max=499,
        step=1,
        ),
    style={'padding-top':'5rem','vertical-align': 'top','text-align': 'center'},
    width=6
)

inputField=dbc.Row([
    dbc.Col(
        inputLabel
    ),
    inputCol
])

## Drop-down menus

In [205]:
def buildOneDdRow(options1,options2):
    """Build a row holding two side-by-side labelled multi-select dropdowns.

    Each argument is a ``(label, channel_names)`` tuple such as ``frontal``.
    The label is shown in bold above a dropdown that lists the channel names.

    NOTE(review): the dropdowns get no ``id``, so no callback can read the
    selection yet.
    """
    def _labelledDropdownCol(group):
        # Half-width column: bold group label on top, multi-select underneath.
        groupLabel = group[0]
        groupChoices = group[1]
        return dbc.Col(
            [
                html.B(html.P(groupLabel)),
                dcc.Dropdown(groupChoices, multi=True),
            ],
            width=6,
        )

    leftCol = _labelledDropdownCol(options1)
    rightCol = _labelledDropdownCol(options2)
    return dbc.Row([leftCol, rightCol], style={'padding-top':'1rem'})

# Heading and hint shown beside the region dropdowns.
# NOTE(review): no callback visible here reads the dropdown selections, so the
# hint describes intended rather than current behaviour — confirm.
dropDownLabel=html.Div(
    children=[
        html.H3("Select the nodes you want to view"),
        "(Selecting no nodes will result in all nodes being displayed)",
    ],
    style={'padding-top':'1rem'},
)

# Label column plus one dropdown row per pair of regions.
dropDownRow=dbc.Row(
    children=[
        dbc.Col(dropDownLabel),
        dbc.Col([
            buildOneDdRow(firstGroup, secondGroup)
            for firstGroup, secondGroup in (
                (frontal, central),
                (parietal, occipital),
                (left, right),
            )
        ]),
    ],
    style={'padding-left':'1rem','padding-right':'1rem'},
)

## Values tab (time selection and predicted emotion)

In [206]:
# Left half of the Values tab: region dropdowns, the numeric time input and the
# slider, stacked vertically and centred.
optionStack=dbc.Col(
    children=[dbc.Stack(children=[dropDownRow, inputField, slider])],
    width=6,
    style={'display':'flex', 'justifyContent':'center','padding-top':'2rem'},
)

# Right half: the element with id "emotion" that the emotion callback fills in.
emotionCol=dbc.Col(
    id="emotion",
    width=6,
    style={
        'padding-top':'2rem',
        "text-align": "center",
        'border-radius': '0% 0% 0% 0% / 0% 0% 0% 0%',
        'box-shadow': '20px 20px rgba(0,0,0,.15)',
        'transition': 'all .4s ease',
    },
)

valSec=dbc.Row(children=[optionStack, emotionCol])

## Chart components

In [207]:
# Column holding the topographic heat map (graph id "heatmap").
heatMap=dbc.Col(
    children=dcc.Graph(id='heatmap'),
    style={'padding-bottom':'3rem'},
)

# Scrollable column of region bar charts, two per row. The graph ids must match
# the Outputs of the chart-update callback.
barCharts=dbc.Col(
    children=[
        dbc.Row([dbc.Col(dcc.Graph(id=graphId), width=6) for graphId in rowGraphIds])
        for rowGraphIds in (
            ("barChartFrontal", "barChartCentral"),
            ("barChartParietal", "barChartOccipital"),
            ("barChartRightTemporal", "barChartLeftTemporal"),
        )
    ],
    style={
        'height':'30%',
        'overflow-x': 'hidden',
        'overflow-y': 'scroll',
    },
)

## Charts tab layout

In [208]:
# Underlined section titles shared by both chart columns.
titleStyle={'border-bottom': 'thin solid black'}

mapLabel=html.Div("Heat Map of Current Node Readings", style=titleStyle)

barLabel=html.Div("Bar Charts For Reading Of Each Nodes", style=titleStyle)

# Charts tab: the heat map takes 4 of Bootstrap's 12 grid columns and the bar
# charts take the remaining 8.
chartSec=dbc.Row(
    children=[
        dbc.Col(children=[mapLabel, heatMap], width=4),
        dbc.Col(children=[barLabel, barCharts], width=8),
    ],
    style={'padding-top':'5vh'},
)


### Helper functions

In [209]:
def decipherEmotion(time_instance):
    """Return an ``html.H1`` naming the predicted emotion at ``time_instance``.

    Parameters
    ----------
    time_instance : int or None
        Index into the global ``y_pred`` list of one-hot rows. ``None`` means
        no time has been chosen yet.

    Returns
    -------
    dash.html.H1
        One of the following:

        * "positive", "neutral" or "negative" when the row is an exact one-hot
          match.
        * "undetermined" for any other row (e.g. all zeros, or several classes
          set). Such rows were previously reported as "negative".
        * A status message when nothing is selected or the index has no
          prediction.
    """
    if time_instance is None:
        return html.H1("No time frame has been selected yet")
    if not 0 <= time_instance < len(y_pred):
        return html.H1("No prediction is available for this time frame")
    # The class order follows the original comparisons.
    # NOTE(review): confirm it matches the label encoding used to train the model.
    oneHotToEmotion = {
        (1, 0, 0): "positive",
        (0, 1, 0): "neutral",
        (0, 0, 1): "negative",
    }
    emotion = oneHotToEmotion.get(tuple(y_pred[time_instance]), "undetermined")
    return html.H1(emotion)

### Layout

In [210]:
# suppress_callback_exceptions=True is required because renderContent swaps
# valSec / chartSec in and out of 'mainContent'. As a result, the callbacks
# target ids ('heatmap', the 'barChart*' graphs, 'emotion', 'inputTime', 'time')
# that are never all in the layout at once: the chart ids are absent initially.
# Without the flag, Dash's callback validation reports those ids as missing.
app = JupyterDash(
    __name__,
    external_stylesheets=[dbc.themes.LUX],
    suppress_callback_exceptions=True,
)
app.layout = dbc.Container([
    # Sits outside 'mainContent', so the stored time survives tab switches.
    dcc.Store(id='timeStore'),
    dbc.Row([
        dbc.Row(
            html.H1('Emotion Prediction With MNE')
        ),
        dbc.Row(
            dcc.Tabs(id="tabSel", value='valSec',
                     children=[
                        dcc.Tab(label='Values', value='valSec'),
                        dcc.Tab(label='Charts', value='chartSec'),
                        ]
                    )
            ),
        # Starts on the Values tab; its children are replaced when the tab changes.
        dbc.Row(id='mainContent',children=valSec)
    ]),
])

### Callbacks

In [211]:
#callback to store time
@app.callback(
    Output('timeStore','data'),
    Input('inputTime','value')
)
def storeTime(data):
    """Copy the selected time into the ``timeStore`` Store.

    The Store lives outside the tab content, so the value is still available
    after the Values tab is swapped out. A blank input (None) leaves the stored
    value untouched.
    """
    if data is None:
        raise PreventUpdate
    return data
    
#callback to render out the tabs
@app.callback(
    Output('mainContent','children'),
    Input('tabSel','value')
)
def renderContent(tab):
    """Return the layout for the chosen tab; any value other than 'valSec' shows the charts."""
    return valSec if tab == 'valSec' else chartSec

#callback to update the emotions display
@app.callback(
    Output('emotion','children'),
    Input('inputTime','value'),
    Input('timeStore','data')
)
def updateEmotions(curTime,storedTime):
    """Show the predicted emotion for the live input, or for the stored time when the input is blank."""
    # A freshly re-mounted input is None; fall back on the persisted value then.
    selectedTime = storedTime if curTime is None else curTime
    return decipherEmotion(selectedTime)

def _regionBarChart(region, time_instance):
    """Build a bar chart of one channel group's readings at row ``time_instance`` of ``df``.

    ``region`` is a ``(title, channel_names)`` tuple such as ``frontal``. Its
    title becomes the chart title.
    """
    regionTitle, regionChannels = region[0], region[1]
    readings = df[regionChannels].iloc[time_instance]
    return px.bar(
        x=regionChannels,
        y=readings.values,
        color=regionChannels,
        title=regionTitle,
        labels=dict(x="Channel Names", y="Voltage", color="Channel Names"),
    )

#callback to update charts
@app.callback(
    Output('heatmap','figure'),
    Output("barChartFrontal", "figure"),
    Output("barChartCentral", "figure"),
    Output("barChartParietal", "figure"),
    Output("barChartOccipital", "figure"),
    Output("barChartLeftTemporal", "figure"),
    Output("barChartRightTemporal", "figure"),
    Input('timeStore','data')
)
def updateChartSec(time_instance):
    """Redraw the heat map and the six region bar charts for the stored time.

    Raises
    ------
    PreventUpdate
        When no time is stored yet, or the index falls outside ``df``'s rows,
        so no partial figures are sent.
    """
    # Guard on the callback argument itself. The former check tested the global
    # `data` array, so it never fired and a None time crashed inside .iloc.
    if time_instance is None or not 0 <= time_instance < len(df):
        raise PreventUpdate

    heatMapFig = px.imshow(points[time_instance])
    heatMapFig.update_xaxes(showticklabels=False)
    heatMapFig.update_yaxes(showticklabels=False)
    # Keep the grid cells square and put a full-width colour bar underneath.
    heatMapFig.update_layout(
        xaxis=dict(scaleanchor='y', constrain='domain'),
        coloraxis=dict(colorbar=dict(orientation='h', thickness=20, len=1.0)),
    )

    # Order must match the Output list above.
    return (
        heatMapFig,
        _regionBarChart(frontal, time_instance),
        _regionBarChart(central, time_instance),
        _regionBarChart(parietal, time_instance),
        _regionBarChart(occipital, time_instance),
        _regionBarChart(left, time_instance),
        _regionBarChart(right, time_instance),
    )

#callback to update the slider and input
@app.callback(
    Output("inputTime","value"),
    Output("time","value"),
    Input("inputTime","value"),
    Input("time","value")
)
def updateSldierOrInput(inputVal,sliderVal):
    """Keep the numeric input and the slider showing the same time.

    Whichever component fired wins. The initial call (no trigger) takes the
    slider's value.
    """
    triggeredBy = dash.callback_context.triggered_id
    syncedValue = inputVal if triggeredBy == "inputTime" else sliderVal
    return syncedValue, syncedValue

### Run Server

In [212]:
# Start the JupyterDash server for the app configured above on local port 4400.
if __name__ == '__main__':
    app.run_server(port=4400)

Dash app running on http://127.0.0.1:4400/
