In [1]:
import numpy as np
import scipy
from scipy.io import loadmat
import mne, glob
from mne_features.feature_extraction import extract_features
import pandas as pd
from ipynb.fs.full.fullDataExtraction import getRawArrayData
from ipynb.fs.full.fullDataExtraction import extarctChannelNames
import plotly.express as px
import plotly.offline as py
import plotly.graph_objects as go
import plotly
import math
from tensorflow import keras
import dash_bootstrap_components as dbc
from ipynb.fs.full.fullDataExtraction import compute_diffEnt

In [2]:
def makeEpochs(rawArrays):
    """Cut each continuous recording into consecutive fixed-length epochs.

    Parameters
    ----------
    rawArrays : dict
        Mapping of clip key -> MNE Raw object (as built by ``getRawArrays``).

    Returns
    -------
    dict
        Same keys, in the same order, each mapped to the ``mne.Epochs`` produced by
        ``mne.make_fixed_length_epochs`` with ``duration=1`` (data preloaded, quiet).
    """
    return {
        clip: mne.make_fixed_length_epochs(rawArrays[clip], duration=1, preload=True, verbose=0)
        for clip in rawArrays
    }

def getRawArrays(matfile, n_trials=15):
    """Wrap each numbered EEG trial of a loaded ``.mat`` dict in an ``mne.io.RawArray``.

    Parameters
    ----------
    matfile : dict
        Result of ``scipy.io.loadmat``. Trial arrays are looked up under
        ``<prefix>1`` ... ``<prefix><n_trials>``, where ``<prefix>`` is the first
        non-metadata key with its numeric suffix removed. The dict is NOT modified.
    n_trials : int, default 15
        Number of numbered trials to read.

    Returns
    -------
    dict
        ``{0: RawArray(trial 1), 1: RawArray(trial 2), ...}`` sharing one ``mne.Info``
        built from ``Preprocessed_EEG/channel-order.xlsx`` at 200 Hz.
    """
    # Skip loadmat's dunder metadata ("__header__", "__version__", "__globals__")
    # without deleting it from the caller's dict; the previous in-place ``del`` made
    # a second call on the same dict raise KeyError.
    trial_keys = [key for key in matfile if not key.startswith("__")]
    # Strip the whole numeric suffix: chopping a single character is wrong when the
    # first trial key happens to be e.g. "djc_eeg10".
    keyName = trial_keys[0].rstrip("0123456789")
    channelNamesExtraction = extarctChannelNames(pd.read_excel("Preprocessed_EEG/channel-order.xlsx"))
    info = mne.create_info(channelNamesExtraction, 200, 'eeg')
    return {
        ind: mne.io.RawArray(matfile[keyName + str(trial)], info, verbose=0)
        for ind, trial in enumerate(range(1, n_trials + 1))
    }

# Per-path cache of the topographic layout grid, so repeated calls (makePostionList
# calls makeValueMatrix once per time sample) do not re-read the Excel file.
_TOPO_TEMPLATE_CACHE = {}


def _load_topo_template(path):
    """Read the topographic layout sheet once per path and cache it as float32."""
    if path not in _TOPO_TEMPLATE_CACHE:
        _TOPO_TEMPLATE_CACHE[path] = pd.read_excel(path).to_numpy().astype(np.float32)
    return _TOPO_TEMPLATE_CACHE[path]


def makeValueMatrix(currentAnalysed, template=None,
                    template_path="Preprocessed_EEG/channel-order(topo).xlsx"):
    """Scatter one value per electrode into the 2-D topographic layout grid.

    Every non-NaN cell of the layout is treated as an electrode slot and is filled,
    in row-major order, with consecutive entries of ``currentAnalysed``. NaN cells
    (no electrode) stay NaN. Extra values beyond the number of slots are ignored.

    Parameters
    ----------
    currentAnalysed : array-like
        Values to place, one per electrode slot, in layout order.
    template : array-like, optional
        Layout grid to use instead of reading ``template_path``; it is never modified.
    template_path : str
        Excel sheet with the layout grid (read once and cached).

    Returns
    -------
    numpy.ndarray
        float32 grid with the same shape as the layout.

    Raises
    ------
    IndexError
        If ``currentAnalysed`` has fewer values than there are electrode slots.
    """
    if template is None:
        template = _load_topo_template(template_path)
    # np.array copies, so the cached/provided template is never written into.
    result = np.array(template, dtype=np.float32)
    slots = ~np.isnan(result)
    n_slots = int(slots.sum())
    values = np.asarray(currentAnalysed, dtype=np.float32).ravel()
    if values.size < n_slots:
        raise IndexError(
            f"layout has {n_slots} electrode cells but only {values.size} values were given"
        )
    # Boolean-mask assignment visits cells in row-major order, matching a nested
    # row/column loop over the grid.
    result[slots] = values[:n_slots]
    return result

def makePostionList(data, n_samples=500):
    """Build one topographic value grid per time sample.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array indexed as ``data[:, column]``; each column is passed to
        ``makeValueMatrix``. (Presumably channels x samples — confirm with caller.)
    n_samples : int, default 500
        Number of leading columns to convert; the default matches the dashboard's
        0-499 time slider.

    Returns
    -------
    list of numpy.ndarray
        ``n_samples`` grids, one per column, in column order.
    """
    return [makeValueMatrix(data[:, column]) for column in range(n_samples)]

def ExtractFeatures(sliced_epochs, features):
    """Compute mne-features for every epoch of every clip and stack the rows.

    Parameters
    ----------
    sliced_epochs : dict
        Mapping of clip key -> ``mne.Epochs`` (as returned by ``makeEpochs``).
    features : list
        Feature specifications passed to ``mne_features.feature_extraction.extract_features``.

    Returns
    -------
    numpy.ndarray
        One row per epoch, clips concatenated in the iteration order of
        ``sliced_epochs``. With no clips, an empty array of shape
        ``(0, len(features) * 62)`` is returned.
    """
    per_clip = []
    for cut in sliced_epochs:
        epoch_array = sliced_epochs[cut].get_data()
        per_clip.append(extract_features(epoch_array, 200, features))
    if not per_clip:
        # Preserve the previous empty-result shape (62 channels per feature).
        return np.zeros((0, len(features) * 62))
    # Stack once instead of re-copying a growing matrix for every clip; this also
    # removes the need for a dummy first row that had to be deleted afterwards.
    return np.vstack(per_clip).astype(np.float64, copy=False)


In [3]:
# Channel names for the visualisation. The column is read from the sheet and "Fp1"
# is prepended (presumably because read_excel consumes it as the header — confirm).
channels=pd.read_excel("Preprocessed_EEG/channel-order(viz).xlsx")
channelNames=channels.iloc[:, 0].to_numpy().tolist()
channelNames.insert(0,"Fp1")
ch_types = ['eeg'] * len(channelNames)
info = mne.create_info(channelNames, ch_types=ch_types, sfreq=200)
info.set_montage('standard_1020')

subject=loadmat('1_20131107.mat')
# Drop rows 57 and 61 of the first trial (indices refer to the original array).
# NOTE(review): presumably electrodes absent from the viz channel list / standard_1020.
data=np.delete(subject['djc_eeg1'], [57, 61], axis=0)
evoked=mne.EvokedArray(data,info)
transposedData=np.transpose(data)
df = pd.DataFrame(transposedData, columns = channelNames).head(500)

model = keras.models.load_model("emotionPredictionModel")
rawArrays=getRawArrays(subject)
slicedEpochs=makeEpochs(rawArrays)
extractedData = ExtractFeatures(slicedEpochs,['hjorth_mobility', ('diffEnt',compute_diffEnt), 'rms', 'skewness'])
points=makePostionList(data)

# One probability row per epoch. Thresholding each class at 0.5 can leave a row with
# no class set, which the dashboard would silently report as "negative"; take the
# most probable class instead, one-hot encoded so y_pred stays a list of 0/1 lists.
y_prob = model.predict(extractedData)
y_pred = np.eye(y_prob.shape[1], dtype=int)[np.argmax(y_prob, axis=1)].tolist()




In [4]:
from dash import Dash, dcc, html, Input, Output, State, ctx
from jupyter_dash import JupyterDash
import base64
import dash

# suppress_callback_exceptions: the chart graphs and the "emotion" header are created
# dynamically by render_content, so their ids are absent from the initial layout.
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.LUX], suppress_callback_exceptions=True)
slider = html.Div(
    [
        dcc.Slider(
            min=0,
            max=499,
            step=1,
            value=7,
            id="slider_time",
            marks=None,
            tooltip={"always_visible": False},
            vertical=True
        ),
    ],
)

# Sidebar components: slider, heatmap jump button and numeric time input.
sliderLabel=dbc.Row(html.Label("Select time frame"),style ={"text-align": "center","padding-left": "10px"})
sliderColumn= dbc.Col([slider],width=2,style = {"height":"30%","text-align": "center",'padding-left':'50%', 'padding-right':'50%'})
heatMapRefButton= dbc.Row(html.Div([html.A(dbc.Button('HeatMap',id="heatbtn"),href='#heatmap')]),style = {"padding": "15rem 1rem 0rem 1rem"})
timeInput= dbc.Row(html.Div(dbc.Input(id="input_time", type="number", min=0, max=499, value=7)),style = {"padding": "1rem 0rem 3rem 1.2rem"})


app.layout=html.Div([
    html.H1('BarChart/LineChart'),
    dcc.Tabs(id='Chart_tabs', value='BarChart', children=[
        dcc.Tab(label='BarChart', value='BarChart'),
        dcc.Tab(label='LineChart', value='LineChart'),
    ]),
    dbc.Col([
        sliderLabel,
        sliderColumn,
        heatMapRefButton,
        timeInput
    ],style ={"top": 0,"left": 0,"bottom": 0,"width": "10rem","padding": "6rem 1rem","position": "fixed"},width=2),
    html.Div(id='BarLineChartTabs'),
    # Graph written by the heatmap callbacks and targeted by the button's '#heatmap' link.
    html.Div(dcc.Graph(id="heatmap"), style={"margin-left": "12rem", "margin-right": "2rem"}),
])

# Callback: show the predicted emotion for the epoch containing the selected sample.
@app.callback(Output("emotion", "children"),
              Input("slider_time", "value"))
def updatePrediction(time_instance):
    """Return an H1 with the model's emotion label for the selected time sample.

    The slider indexes raw samples (the rows of ``df`` / entries of ``points``),
    while ``y_pred`` has one one-hot row per 1-second epoch (``makeEpochs`` uses
    ``duration=1`` at 200 Hz, and the first clip's epochs come first). The sample
    index is therefore mapped to its epoch before lookup.
    """
    if time_instance is None:
        # Empty/invalid number box: keep the currently displayed emotion.
        return dash.no_update
    samples_per_epoch = 200
    prediction = y_pred[int(time_instance) // samples_per_epoch]
    labels = {(1, 0, 0): "positive", (0, 1, 0): "neutral"}
    emotion = labels.get(tuple(prediction), "negative")
    return html.H1(["Current Emotion", html.Br(), html.Br(), emotion])

def _regionChartLayout(id_prefix, heading):
    """Build the 2x3 grid of regional graphs shared by both chart tabs.

    Graph ids are ``id_prefix`` + region (e.g. ``"barChartFrontal"``), matching the
    ids written by the chart-update callbacks.
    """
    def graph_col(region):
        return dbc.Col(dcc.Graph(id=id_prefix + region), width=4)

    return html.Div(children=[
        html.H3(heading),
        dbc.Row(
            dbc.Col([
                dbc.Row([
                    html.H1(id="emotion")
                ], style={"text-align": "center", 'padding-top': '5%'}),
                dbc.Row([graph_col(region) for region in ("Frontal", "Central", "RightTemporal")]),
                dbc.Row([graph_col(region) for region in ("Parietal", "Occipital", "LeftTemporal")]),
            ], width=10, style={"margin-left": "12rem", "margin-right": "2rem"})
        )
    ])


@app.callback(
    Output('BarLineChartTabs', 'children'),
    Input('Chart_tabs', 'value')
)
def render_content(tab):
    """Render the body of the selected chart tab ('BarChart' or 'LineChart').

    Returns None for any other tab value.
    """
    if tab == 'BarChart':
        return _regionChartLayout("barChart", "Bar Chart")
    elif tab == 'LineChart':
        # Heading previously read 'Bar Chart' here (copy-paste error).
        return _regionChartLayout("lineChart", "Line Chart")
    return None
            
# Callback: redraw the six regional bar charts whenever the time slider moves.
@app.callback(Output("barChartFrontal", "figure"),
              Output("barChartCentral", "figure"),
              Output("barChartParietal", "figure"),
              Output("barChartOccipital", "figure"),
              Output("barChartLeftTemporal", "figure"),
              Output("barChartRightTemporal", "figure"),
              Input("slider_time", "value"))
def updateBarCharts(time_instance):
    """Bar chart of every channel's voltage, per brain region, at the selected sample.

    Returns six figures in the order of the callback outputs.
    """
    # (title, channels) in the same order as the Outputs above.
    regions = [
        ("Frontal", ['Fp1', 'Fpz', 'Fp2', 'AF3', 'AF4', 'F1', 'Fz', 'F2']),
        ("Central", ['FC1', 'FCz', 'FC2', 'C3', 'C1', 'Cz', 'C2', 'C4']),
        ("Parietal", ['CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6']),
        ("Occipital", ['PO7', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'PO8', 'O1', 'Oz', 'O2']),
        # 10-20 convention: odd-numbered electrodes lie over the left hemisphere and
        # even-numbered ones over the right (these two lists were previously swapped).
        ("Left Temporal", ['F7', 'F5', 'F3', 'FT7', 'FC5', 'FC3', 'C5', 'TP7', 'CP5', 'P7']),
        ("Right Temporal", ['F4', 'F6', 'F8', 'FC4', 'FC6', 'FT8', 'C6', 'CP6', 'TP8', 'P8']),
    ]
    if time_instance is None:
        # Empty/invalid number box: leave the current figures in place.
        return (dash.no_update,) * len(regions)
    axis_names = dict(y="Voltage", color="Channel Names")
    figures = []
    for title, channels in regions:
        voltages = df[channels].iloc[int(time_instance)]
        figures.append(px.bar(x=channels, y=voltages.values, color=channels, title=title, labels=axis_names))
    return tuple(figures)
            
# Callback: redraw the six regional line charts whenever the time slider moves.
@app.callback(Output("lineChartFrontal", "figure"),
              Output("lineChartCentral", "figure"),
              Output("lineChartParietal", "figure"),
              Output("lineChartOccipital", "figure"),
              Output("lineChartLeftTemporal", "figure"),
              Output("lineChartRightTemporal", "figure"),
              Input("slider_time", "value"))
def updateLineChart(time_instance):
    """Line chart of each region's channels from sample 0 up to the selected sample.

    Returns six figures in the order of the callback outputs.
    """
    # (title, channels) in the same order as the Outputs above.
    regions = [
        ("Frontal", ['Fp1', 'Fpz', 'Fp2', 'AF3', 'AF4', 'F1', 'Fz', 'F2']),
        ("Central", ['FC1', 'FCz', 'FC2', 'C3', 'C1', 'Cz', 'C2', 'C4']),
        ("Parietal", ['CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6']),
        ("Occipital", ['PO7', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'PO8', 'O1', 'Oz', 'O2']),
        # 10-20 convention: odd-numbered electrodes are left hemisphere, even-numbered
        # right (these two lists were previously swapped).
        ("Left", ['F7', 'F5', 'F3', 'FT7', 'FC5', 'FC3', 'C5', 'TP7', 'CP5', 'P7']),
        ("Right", ['F4', 'F6', 'F8', 'FC4', 'FC6', 'FT8', 'C6', 'CP6', 'TP8', 'P8']),
    ]
    if time_instance is None:
        # Empty/invalid number box: leave the current figures in place.
        return (dash.no_update,) * len(regions)
    # +1 so the selected sample itself is plotted (and slider value 0 is not empty),
    # consistent with the bar charts, which show row ``time_instance``.
    window = df.head(int(time_instance) + 1)
    return tuple(px.line(window[channels], title=title) for title, channels in regions)
        
        
# Callback: keep the numeric time input and the time slider in sync.
@app.callback(
    Output("input_time", "value"),
    Output("slider_time", "value"),
    Input("input_time", "value"),
    Input("slider_time", "value"),
)
def callbacktest(input_value, slider_value):
    """Mirror whichever of the number box / slider changed onto the other one.

    If the number box triggered the call its value wins, otherwise the slider's does.
    An empty or invalid number box (value ``None``) leaves both components untouched
    instead of pushing ``None`` into the slider (which every chart callback indexes
    with). Typed values are truncated to int and clamped to the slider range.
    """
    min_time, max_time = 0, 499  # must match the slider/input min and max in the layout
    trigger_id = dash.callback_context.triggered[0]["prop_id"].split(".")[0]
    value = input_value if trigger_id == "input_time" else slider_value
    if value is None:
        return dash.no_update, dash.no_update
    value = min(max(int(value), min_time), max_time)
    return value, value

@app.callback(
    Output("heatmap", "style"),
    Input('heatbtn', 'n_clicks'),
)
def showHeatMap(buttonHeat):
    """Placeholder handler for the 'HeatMap' button.

    Every call sets the heatmap's ``style`` to ``None``, both when the button
    triggered it and on any other trigger (e.g. the initial page load).
    NOTE(review): the intended style change (e.g. toggling visibility) is not implemented.
    """
    if ctx.triggered_id != 'heatbtn':
        # Not a button click.
        return None
    # TODO: return the desired style for the heatmap when the button is clicked.
    return None

# Callback: redraw the topographic heatmap when the slider is moved.
@app.callback(Output("heatmap", "figure"),
              Input("slider_time", "value"))
def updateHeatMap(time_instance):
    """Render the topographic value grid of the selected sample as a heatmap.

    Uses ``points[time_instance]`` (built by ``makePostionList``); axis tick labels
    are hidden, the aspect ratio is locked and the colour bar is horizontal.
    """
    if time_instance is None:
        # Empty/invalid number box: keep the current heatmap.
        return dash.no_update
    heatMap = px.imshow(points[int(time_instance)])
    heatMap.update_xaxes(showticklabels=False)
    heatMap.update_yaxes(showticklabels=False)
    heatMap.update_layout(
        xaxis=dict(scaleanchor='y', constrain='domain'),
        coloraxis=dict(colorbar=dict(orientation='h', thickness=20, len=0.5)),
    )
    return heatMap


if __name__ == "__main__":
    # Start the JupyterDash server with Dash debug tooling enabled.
    server_options = {"debug": True}
    app.run_server(**server_options)

Dash app running on http://127.0.0.1:8050/
