In [None]:
# change into the root directory of the project (this notebook lives in ./dev)
import os
if os.path.basename(os.getcwd()) == "dev":
    os.chdir('..')

# get the current directory
cwd = os.getcwd()

# Print the current working directory
print("Current working directory: {0}".format(cwd))


# Set path, where to save files of tables
PATH = cwd
HDF_PATH = PATH + '/data/hdf/'
IMG_PATH = PATH + '/images/05/'
print(HDF_PATH)


import importlib

import numpy as np
import scipy
# `import scipy` alone does not expose scipy.ndimage / scipy.signal on older SciPy
# releases (no lazy submodule loading); later cells use both, so import them explicitly.
import scipy.ndimage
import scipy.signal
from scipy import stats


# "wc-adap" contains a hyphen, which is not a valid Python identifier, so
# `from neurolib.models.wc-adap import WCModel` is a SyntaxError. Import the module by
# name instead, falling back to the underscore spelling used by the other model module.
try:
    WCModel = importlib.import_module("neurolib.models.wc-adap").WCModel
except ModuleNotFoundError:
    WCModel = importlib.import_module("neurolib.models.wc_adap").WCModel
from neurolib.models.wc_input import WCModel_input

from neurolib.utils.parameterSpace import ParameterSpace
from neurolib.optimize.exploration import BoxSearch
import neurolib.utils.functions as func
import neurolib.utils.devutils as du

#import neurolib.utils.brainplot as bp
import neurolib.optimize.exploration.explorationUtils as eu

from neurolib.utils.loadData import Dataset

from Derivations import Derivations

#for the brainplot functions
from Brainplot import plot_brain
from Brainplot import plot_involvement_distribution
import xarray as xr
from neurolib.utils.signal import Signal
import matplotlib.pyplot as plt


In [None]:
# Load the "gw" dataset (structural connectivity Cmat and, presumably, a fiber-length matrix
# Dmat; fcd=True presumably also loads functional-connectivity-dynamics data — confirm).
ds = Dataset("gw", fcd=True)

In [None]:
# Adaptive Wilson-Cowan network built on the dataset's connectome.
wc = WCModel(Cmat = ds.Cmat, Dmat = ds.Dmat)

## Run the simulation

In [None]:

# Best-fit parameter set applied to the model; 'duration' is 11 minutes in ms.
fix = {'duration': 11*60.0*1000, 'sigma_ou': 0.287031, 'K_gl': 2.573845, 
      # 'adap_init': 0.0 * np.random.uniform(0.99, 1, (1, 1)), # together with setting a_adap to zero, this switches the adaptation off
       'tau_exc': 2.5, 'tau_inh': 3.75, 
       'c_excexc': 16, 'c_excinh': 12, 'c_inhexc': 12, 'c_inhinh': 3, 
       'a_exc': 1.0, 'a_inh': 1.0, 
       'mu_exc': 5.0, 'mu_inh': 5.0,
       'a_adap': 98.381822, 'tau_adap': 3050.402224,
       'exc_ext': 4.960871, 'inh_ext': 2.668888}
wc.params.update(fix)


# Raw row of the fit-results table: exc_ext, inh_ext, K_gl, a_adap, tau_adap, sigma_ou,
# followed by three further values (presumably fit scores — TODO confirm).
#4.960871	2.668888	2.573845	98.381822	3050.402224	0.287031	0.513186	0.364011	0.734326

In [None]:
# Analysis helper (UP/DOWN detection, durations, involvement) bound to the model and best-fit params.
dev = Derivations(model=wc, params=fix)

In [None]:
# Integrate the model. NOTE(review): no random seed is set in this notebook, so if the model is
# stochastic (sigma_ou presumably sets the noise strength) results differ between runs.
wc.run()

# Prepare the data

In [None]:
def filter_peaks(peaks, involvement, threshold_max, threshold_min = 0):
    around = int(500 / wc.params.dt)
    return [p for p in peaks if (involvement[p] > threshold_min and involvement[p] <= threshold_max)]

In [None]:
# Discard the first 60 s of the simulation as an initial transient. `cut_off` is the number
# of samples KEPT from the end of the run; it is derived from the integration step dt (ms),
# so it stays correct for any dt (the kept window is duration - 60 s).
TRANSIENT_MS = 60 * 1000.0
cut_off = int(round((fix['duration'] - TRANSIENT_MS) / wc.params.dt))
assert cut_off > 0, "the simulation must be longer than the discarded transient"
#print(cut_off)

x = wc.exc[:, -cut_off:]
x_adap = wc.adap[:, -cut_off:]

In [None]:
# Optional multi-oscillation check (currently disabled):
#oss = dev.checkMultiOsc(x)

#print('Oscillations were checked')
# Per-node UP/DOWN threshold: 20% of each node's maximum of wc.output.
# NOTE(review): wc.output covers the whole run, whereas the states below use the
# post-transient slice x — presumably intended, confirm.
thresh = 0.2 * np.max(wc.output, axis=1)

# Classify each node's post-transient activity into UP/DOWN windows; the meaning of
# filter_long/dur is defined in Derivations.getUpDownWindows (not visible here).
states = [dev.getUpDownWindows(x[k], thresh[k], filter_long=True, dur=25) for k in range(len(x))]
states = np.array(states)
stateLengths = dev.getStateLengths(states)

print('State lengths are done.')

durations = dev.getDurationsNdim(stateLengths)

# Per-node collections of UP (index 0) and DOWN (index 1) episode durations.
up_all = durations[0]
down_all = durations[1]

# Flatten the episode durations of all nodes and convert them to seconds (dt is in ms);
# assumes getDurationsNdim reports durations in integration steps — confirm.
dt_to_sec = wc.params.dt / 1000
up_dur = [u * dt_to_sec for up in up_all for u in up]
down_dur = [d * dt_to_sec for down in down_all for d in down]

print('durations done')
# Mean episode duration per node, skipping nodes without any episode.
# NOTE(review): `if array` assumes list-like entries; a multi-element NumPy array would raise here.
up_dur_mean = [dev.getMean(np.array(array)) for array in up_all if array]
down_dur_mean = [dev.getMean(np.array(array)) for array in down_all if array]


# Involvement per sample (presumably the fraction of nodes in the DOWN state).
unfiltered_involvement = dev.getInvolvement(states)
print('involvement is done')

In [None]:
# Smooth the involvement with a Gaussian kernel (sigma = 2000 samples, i.e. 200 ms at dt = 0.1 ms).
involvement = scipy.ndimage.gaussian_filter1d(unfiltered_involvement, 2000)

# The duration derivations depend on tiny differences between intervals, and spindles are
# already filtered out, so the UNfiltered involvement time series is used here; for
# distinguishing between global and local waves (below), however, the smoothed one is used.
up_bin_means, down_bin_means, bin_edges = dev.getStatesInvolvementDistribution(states, unfiltered_involvement, nbins=10)
print('up down involvement is done')

# Events = peaks of the smoothed involvement with height >= 0.1, at least 10000 samples apart.
# NOTE: needs scipy.signal to be loaded (explicit import or SciPy's lazy submodule loading).
peaks = scipy.signal.find_peaks(involvement, height=0.1, distance=10000)[0]

# Split events by peak involvement: (25, 50] -> local, (50, 75] and (75, 100] -> global.
peaks25 = filter_peaks(peaks, involvement, 0.50, 0.25)
peaks50 = filter_peaks(peaks, involvement, 0.75, 0.50)
peaks75 = filter_peaks(peaks, involvement, 1, 0.75)

# Inter-event intervals in seconds: sample differences * dt [ms] / 1000.
# `peaks50 + peaks75` concatenates the two lists returned by filter_peaks, then sorts them in time.
global_iei = np.diff(np.sort(peaks50 + peaks75).tolist())/1000*wc.params.dt
local_iei = np.diff(peaks25)/1000*wc.params.dt

# Start the plotting

In [None]:
#Packages for plotting
from plotly.offline import init_notebook_mode, plot, iplot
from plotly.subplots import make_subplots
import plotly.offline as pyo

import chart_studio.plotly as py #chart_studio has to be installed: pip install chart_studio
import plotly.graph_objs as go
import plotly.figure_factory as ff
import plotly.express as px

In [None]:
from Templates import template
from Templates import brain_result_color_list
from Templates import brain_result_colors

In [None]:
# Color lookup used by the figures below (keys such as 'up_duration', 'down_duration', 'UP', 'DOWN').
colors = brain_result_colors

In [None]:
# Default panel size: one third of the template width and two thirds of its height.
width = template.layout.width * (1/3)
height = template.layout.height * (2/3)

# 0. Dominant frequency per node degree for the best fit

In [None]:
from Topology import Topology
# Directed in-degree of every node, computed by the project's Topology helper (presumably from Cmat).
# NOTE(review): ext_input=[2.4, 1.12] differs from fix['exc_ext'] / fix['inh_ext'];
# presumably irrelevant for the degree computation — confirm.
top = Topology(wc, ext_input=[2.4, 1.12], fix_params=fix, Cmat=ds.Cmat)
top.getDirectedNodeDegreeIN()
nd = top.directedNodeDegreesIN
#print('node number of lowest degree: ', nd[np.argmin(nd)])
#print('node number of highest degree: ', np.argmax(nd))

In [None]:
from scipy import signal
frequencies, psd =  signal.welch(x, 1000/wc.params.dt, 
                                 window='hanning', 
                                 nperseg=int(6 * 1000 / wc.params.dt) , 
                                 scaling='spectrum')

        
idx_dominant_frequ = np.argmax(psd, axis=1)

In [None]:
dom_frequs = frequencies[idx_dominant_frequ]
print('The dominant frequencies for the best fit per node: ', dom_frequs)
idx = np.argmax(np.sum(psd,axis=0))
f = frequencies[idx]
print('The dominant frequency over all nodes for the best fit is: ', f)

In [None]:
# Dominant frequency of each node versus its in-degree, for the best-fit parameters.
fig = go.Figure()


fig.add_trace(go.Scatter(x=nd, y=dom_frequs, mode='markers',
                         marker=dict(color='black'),
                         name='dominant frequency'))

# NOTE: the annotated node degrees / frequencies below are hard-coded and do not
# update with the data; re-check them whenever the parameters change.
fig.update_layout(template=template,
                  annotations=[
                      dict(x=-0.2, y=1.1, text='(a)', font=dict(color='black')),
                      dict(x= 0.155, y=29.33, xref='x', yref='y',
                          showarrow=True, axref='x', ayref='y',
                          ax=0.6,ay=30,
                          text='last node of noise<br>dominated oscillations<br>node degree: 0.152<br>dom. frequency: 30.05',
                          font_size=18),
                      dict(x= 0.17, y=0.66, xref='x', yref='y',
                          showarrow=True, axref='x', ayref='y',
                          ax=0.4,ay=15,
                          text='first node of adaptation<br>dominated oscillations<br>node degree: 0.17<br>dom. frequency: 0.66',
                          font_size=18)
                  ],
                 width=width, height=height,
                # legend=dict(x=0.65, y=0.5),
                 xaxis=dict(title_text='Node degree'),
                 yaxis=dict(title_text='Frequency [Hz]', tickvals=[0,10,20,30]))

fig.show()

## 1. State Durations per Involvement

In [None]:
# Mean UP and DOWN durations per involvement bin (10 bins from getStatesInvolvementDistribution).
fig = go.Figure()

# Bar positions in percent: left bin edges for UP, shifted by 0.05 for DOWN (half a bin width
# if the 10 bins span [0, 1] — assumption) so the two bars sit side by side.
x1=bin_edges[:-1] * 100
x2=(bin_edges[:-1] + 0.05) * 100

# NOTE(review): only the UP means are reversed; presumably UP durations are binned by UP
# involvement (1 - DOWN involvement) — confirm against getStatesInvolvementDistribution.
fig.add_trace(go.Bar(x=x1, y=up_bin_means[::-1],
                    name='up',
                    marker=dict(line_width=0.5, color=colors['up_duration'])))
fig.add_trace(go.Bar(x=x2, y=down_bin_means,
                    name='down',
                    marker=dict(line_width=0.5, color=colors['down_duration'])))

# The y-axis title is drawn as a rotated annotation.
# NOTE(review): the axis is labelled in ms — confirm the unit of the returned bin means.
fig.update_layout(template=template,
                  annotations=[
                      dict(x=-0.255,y=0.5,text='Duration [ms]', font_size=26,textangle=-90),
                      dict(x=-0.255,y=1.1, text='(b)')
                  ],
                  width=width, height=height,
                  xaxis=dict(title_text='Involvement [%]', range=[0,101]),
                  yaxis=dict(title_text='', tickvals=[0,400,800]),
                  legend=dict(x=0.01, y=1.03),
                 margin=dict(l=80))

fig.show()

In [None]:
# Raw DOWN bin means (inspection output).
down_bin_means

In [None]:
# Raw UP bin means (inspection output).
up_bin_means

## 2. Distribution of State-Durations

In [None]:
# Histograms of all UP / DOWN episode durations (seconds) in 0.2 s bins, normalised to percent.
# The y-axis is logarithmic, hence the rotated 'Log probability' axis label.
fig = go.Figure()

fig.add_trace(go.Histogram(x=up_dur, histnorm='percent', 
                           marker=dict(line_width=0.75, color=colors['up_duration']),
                           xbins_size=0.2,
                           name='up'))
fig.add_trace(go.Histogram(x=down_dur, histnorm='percent', 
                           marker=dict(line_width=0.75, color=colors['down_duration']),
                           xbins_size=0.2,
                           name='down'))


fig.update_layout(template=template, 
                  annotations=[
                      dict(x=-0.22,y=0.5,text='Log probability', font_size=26,textangle=-90),
                      dict(x=-0.21,y=1.1, text='(c)')
                  ],
                  width=width, height=height,
                  xaxis=dict(title_text='Duration [s]',tickvals=[0,1,2,3,4,5]),
                  yaxis=dict(title_text='', showticklabels=True,
                            tickvals=[0.1,10]),
                  barmode='overlay',
                 legend=dict(x=0.7,y=1.02))

fig.update_traces(opacity=0.9)

fig.update_yaxes(type='log')
fig.update_xaxes(ticks='outside', tick0=0)



fig.show()

In [None]:
# Last 20 s of four example nodes' post-transient excitatory activity.
fig = go.Figure()

# Number of samples in 20 s (dt is the integration step in ms).
n_last_20s = int(round(20 * 1000 / wc.params.dt))
time = np.linspace(0, 20, n_last_20s)

for node in [23, 25, 68, 77]:
    # Named traces (1-based node numbers, as in the other figures) keep the legend readable.
    fig.add_trace(go.Scatter(x=time, y=x[node][-n_last_20s:], mode='lines',
                             name=f'Node #{node + 1}'))

fig.update_layout(template=template,
                  width=800, height=400,
                  xaxis=dict(title_text='Time [s]'),
                  yaxis=dict(title_text='E(t)'))

fig.show()

## 3. Involvement in DOWN over time

In [None]:
#Plot involvement timeseries (last 60 s of the analysed window):
fig = go.Figure()

involvement_prozent = unfiltered_involvement * 100  # involvement in percent

# Number of samples in 60 s (dt is the integration step in ms).
n_last_minute = int(round(60 * 1000 / wc.params.dt))

# With this many points plotly draws the trace in 'lines' mode, where `marker.color` has no
# visible effect, so the colour is set on the line itself.
fig.add_trace(go.Scatter(x=np.linspace(0, 60, n_last_minute), y=involvement_prozent[-n_last_minute:],
                         mode='lines', line=dict(color='salmon')))

fig.update_layout(template=template,
                  annotations=[
                      dict(x=-0.15, y=1.32, text='(d)', font=dict(color='black')),
                      dict(x=-0.15, y=0.5, text='Involvement [%]', font_size=26, textangle=-90)
                  ],
                  width=template.layout.width*0.5, height=height*(2/3),
                  xaxis=dict(title_text='Time [s]', tickvals=[0,20,40,60]),
                  yaxis=dict(title_text='', tickvals=[0,50,100]),
                 margin=dict(l=82))

fig.show()

In [None]:
print("Mean involvement: ", np.mean(unfiltered_involvement))

In [None]:
print(np.sum(unfiltered_involvement<0.5)/len(unfiltered_involvement)*100, "% of slow oscillations were detected in less than 50% of regions")

## 4.1 Distribution of inter-event intervals (IEI), distinguished by global vs. local events

In [None]:
# Histograms of inter-event intervals (s) of local and global events, 0.25 s bins, in percent.
fig = go.Figure()


fig.add_trace(go.Histogram(x=local_iei, histnorm='percent',
                           xbins_size=0.25, 
                           marker=dict(color='gray', 
                                       line=dict(width=0.75)),
                           name='local'))
fig.add_trace(go.Histogram(x=global_iei, histnorm='percent', 
                           xbins_size=0.25,
                           marker=dict(color='green', 
                                       line=dict(width=0.75)),
                           name='global'))


# The x-axis is switched to log below; plotly then reads `range` in log10 units,
# so range=[0,1] shows 10^0 .. 10^1 s.
fig.update_layout(template=template, 
                  annotations=[
                      dict(x=-0.2,y=1.1, text='(a)')
                  ],
                  width=width, height=height,
                  xaxis=dict(title_text='Inter-event interval [s]', range=[0,1], tickvals=[1,2,3,4,5,6,10], tickfont_size=18),
                  yaxis=dict(title_text='Fraction [%]', tickfont_size=18),#, showticklabels=False),
                  barmode='overlay',
                 legend=dict(x=0.7,y=1.02))

fig.update_traces(opacity=0.75)
fig.update_xaxes(type='log')

fig.show()

## 4.2 Involvement in DOWN states, global vs. local

In [None]:
# Histogram of the unfiltered involvement in percent (5 % bins). The 20 bar colours mark
# 0-25 % (lightgray), 25-50 % (gray, 'local') and 50-100 % (green, 'global'), matching the
# bands used with filter_peaks for local/global events.
# NOTE(review): the colour list assumes exactly 20 bins starting at 0 %.
fig = go.Figure()

fig.add_trace(go.Histogram(x=unfiltered_involvement*100, histnorm='percent',
                         #  nbinsx=15,
                           xbins_size=5.0,
                           marker_color=['lightgray', 'lightgray', 'lightgray', 'lightgray', 'lightgray', 
                                        'gray', 'gray', 'gray', 'gray', 'gray', 
                                        'green', 'green', 'green', 'green', 'green', 
                                        'green', 'green', 'green', 'green', 'green', ],
                           marker_line_width=0.75))


fig.update_layout(template=template, 
                  annotations=[
                      dict(x=0.25, y=0.5, text='local', font_color='gray'),
                      dict(x=0.9, y=0.2, text='global', font_color='green'),
                      dict(x=-0.2,y=1.1, text='(c)')
                  ],
                  width=width, height=height,
                  xaxis=dict(title_text='Involvement [%]',
                             tickvals=[0, 50, 100], range=[0,101]),
                  yaxis=dict(title_text='Fraction [%]', showticklabels=True))


fig.show()

## 5. UP/DOWN states per node over time (heatmap, yellow-blue)

In [None]:
fig = go.Figure()

steps = len(states[0])
# Duration of the analysed window in seconds (dt is the integration step in ms).
time = steps * wc.params.dt / 1000
# Show only the last 60 s.
n_show = int(round(60 * 1000 / wc.params.dt))
fig.add_trace(go.Heatmap(z=states[:, -n_show:], x=np.linspace(0, time, steps)[-n_show:],
                         colorscale=[[0.0, colors['DOWN']], [0.5, colors['DOWN']],
                                     [0.5, colors['UP']], [1.0, colors['UP']]],
                         colorbar=dict(nticks=2,
                                       tickvals=[0.05, 0.95],
                                       ticktext=['DOWN', 'UP'],
                                       tickangle=90)))

fig.update_layout(template=template,
                  annotations=[
                      dict(x=-0.085,y=1.1, text='(b)')
                  ],
                 width=800, height=400)
# Label the displayed last 60 s as 0..60 s, independent of the total simulation length.
tick_positions = [time - 60 + offset for offset in (0, 20, 40, 60)]
fig.update_xaxes(title_text='Time [s]', tickvals=tick_positions, ticktext=['0','20','40','60'])
fig.update_yaxes(title_text='Node', tickvals=[0,19,39,59,79], ticktext=['1','20','40','60','80'])
fig.show()

In [None]:
# Save the heatmap figure from the previous cell. write_image fails if the target
# directory does not exist, so create it first.
os.makedirs(IMG_PATH, exist_ok=True)
fig.write_image(IMG_PATH + 'states_only_correctedNodes.png')

## Plot corresponding distribution of states over brain map

Open problem: in the ALN model the node_mean_down_phases are all negative, but here they are not. Why?

In [None]:
# NOTE(review): `t` is not used in the visible cells.
t = wc.t[-cut_off:]

# Re-detect events, now only near-global ones: unfiltered involvement >= 0.97, at least
# 30000 samples apart. This overwrites the `peaks` array of the IEI analysis; section 7 reuses it.
peaks = scipy.signal.find_peaks(unfiltered_involvement, height=0.97, distance=30000)[0]
#print(node_mean_phases_down)
# Sample index (not a time) of the 4th-last such event; requires at least 4 peaks.
dt_min = peaks[-4]
print(dt_min)

# Offsets (in samples) around the event at which the brain map is drawn.
#deltas = np.linspace(-20000, 10000, 21)
deltas = [-4000, -3000, -2000, 0, 1000, 3000]
for delta in deltas:
    # Sum of node states at this sample (number of nodes in state 1, assuming 0/1 states).
    print(delta, np.sum(states[:, int(dt_min + delta)]))
    # Node colour = state, node size scaled by in-degree; the title gives the offset in ms.
    plot_brain(wc, ds, color=states[:, int(dt_min + delta)], size=np.multiply(800,nd), title=f"t = {int(delta*wc.params.dt)} ms", cmap='plasma', cbar=False, clim=[0, 1])
    #plt.savefig(IMG_PATH + f"frame_{delta}.pdf", transparent=True)
    
    plt.show()

In [None]:
# Sample indices of the detected near-global events (inspection output).
peaks

# 6. Power Spectrum

In [None]:
import neurolib.utils.functions as func

maxfr = 10

# Mean power spectrum over nodes of the post-transient activity `x`, which is the same data
# the per-node spectra are computed from (wc.exc would also include the discarded transient).
# 6 s windows, frequencies below maxfr Hz.
model_frs, model_pwrs = func.getMeanPowerSpectrum(x, dt=wc.params.dt, maxfr=maxfr, spectrum_windowsize=6)

# Bins below maxfr form a prefix (Welch frequencies are ascending). Drop the 0 Hz (DC) bin
# from BOTH arrays so every power value stays paired with its own frequency.
n_band = int(np.count_nonzero(model_frs < maxfr))
model_frs_plot = model_frs[1:n_band]
model_pwrs_plot = model_pwrs[1:n_band]

In [None]:
import scipy.signal as signal

fig = make_subplots(rows=1, cols=1, specs=[[{'secondary_y': True}]])

maxfr = 10
spectrum_windowsize=6
# Sampling rate in Hz derived from the integration step (dt in ms), consistent with nperseg.
fs = 1000 / wc.params.dt

# Per-node power spectra (left axis) of the post-transient activity.
for act in x:
    frequ, power_spectral_density = signal.welch(act, fs, window='hann',
                                                 nperseg=int(spectrum_windowsize * 1000 / wc.params.dt),
                                                 scaling='spectrum')
    # Keep the bins below maxfr (frequencies are ascending) and drop the 0 Hz bin from
    # BOTH arrays so frequencies and powers stay aligned element by element.
    n_band = int(np.count_nonzero(frequ < maxfr))
    frequ = frequ[1:n_band]
    power_spectral_density = power_spectral_density[1:n_band]
    fig.add_trace(go.Scatter(x=frequ, y=power_spectral_density,
                            showlegend=False), row=1, col=1, secondary_y=False)

# Mean spectrum over nodes (right axis). In 'lines' mode only the `line` styling is visible.
fig.add_trace(go.Scatter(x=model_frs_plot, y=model_pwrs_plot,
                         showlegend=False, mode='lines',
                         line=dict(color='black', width=3)), row=1, col=1, secondary_y=True)

fig.update_layout(template=template,
                  annotations=[
                      dict(x=-0.195, y=1.32, text='(e)', font=dict(color='black')),
                      dict(x=1.14, y=0.5, text='Mean PS', font=dict(size=26,color='black'), textangle=90),
                      dict(x=-0.2, y=0.5, text='PS [V**2]', font=dict(size=26,color='black'), textangle=-90),
                  ],
                 width=template.layout.width*0.5, height=height*(2/3),
                 yaxis=dict(type='log', title_text='', tickvals=[0,0.001,0.01]),
                 yaxis2=dict(type='log',title_text='', tickvals=[0,0.001,0.01]),
                 margin=dict(l=82,r=80))

fig.update_xaxes(title_text='Frequency [Hz]')

fig.show()

In [None]:
import dill
# Load the precomputed mean N3 EEG power spectrum. The `with` block closes the file handle.
# NOTE: dill (like pickle) can execute arbitrary code while loading — only load trusted files.
with open("./data/mean_eeg_power_N3.dill", "rb") as eeg_file:
    f_eeg, mean_eeg_power = dill.load(eeg_file)

In [None]:
# Compare the model's mean power spectrum with the empirical N3 EEG spectrum.
# add_trace(..., row=..., col=..., secondary_y=...) only works on a figure created with
# make_subplots; on a plain go.Figure() plotly raises an exception.
fig = make_subplots(rows=1, cols=1, specs=[[{'secondary_y': True}]])

# assumes f_eeg / mean_eeg_power are matching 1-D arrays (frequency [Hz], power) — TODO confirm
fig.add_trace(go.Scatter(x=f_eeg, y=mean_eeg_power,
                         mode='lines', name='EEG (N3)',
                         line=dict(color='gray', width=3)), row=1, col=1, secondary_y=False)
fig.add_trace(go.Scatter(x=model_frs_plot, y=model_pwrs_plot,
                         mode='lines', name='model',
                         line=dict(color='black', width=3)), row=1, col=1, secondary_y=True)

fig.update_layout(template=template)
fig.update_xaxes(title_text='Frequency [Hz]')

fig.show()

# 7. Time Series

In [None]:
# Sample index of the 4th-last near-global DOWN event detected above.
# `peaks` indexes the involvement/state time series, which are computed from the
# post-transient slice `x`. Slicing wc.exc with these indices would be shifted by the
# discarded transient, so the windows are cut from `x`.
dt_min=peaks[-4]

n_before, n_after = 4000, 3000  # window around the event, in samples
if dt_min - n_before < 0 or dt_min + n_after > x.shape[1]:
    raise ValueError(f"Event at sample {dt_min} is too close to the signal edge for a "
                     f"[-{n_before}, +{n_after}] sample window.")

eins = x[0][dt_min-n_before:dt_min+n_after]
zwei = x[22][dt_min-n_before:dt_min+n_after]
drei = x[78][dt_min-n_before:dt_min+n_after]

# Time relative to the event in ms (dt is the integration step in ms).
time = np.linspace(-n_before * wc.params.dt, n_after * wc.params.dt, n_before + n_after)

In [None]:
# Activity of three example nodes around the selected global DOWN event.
# `time` is relative to the event in milliseconds. In mode='lines' plotly ignores
# `marker.color`, so each colour is set on `line`. Node numbers are 1-based.
fig = go.Figure()

fig.add_trace(go.Scatter(x=time, y=eins,
                        mode='lines', name='Node #1',
                        showlegend=True,
                        line=dict(color='black')))

fig.add_trace(go.Scatter(x=time, y=zwei,
                        mode='lines', name='Node #23',
                        showlegend=True,
                        line=dict(color='gray')))

fig.add_trace(go.Scatter(x=time, y=drei,
                        mode='lines', name='Node #79',
                        showlegend=True,
                        line=dict(color='green')))

fig.update_layout(template=template,
                  annotations=[
                      dict(x=-0.085, y=1.4, text='(d)', font=dict(color='black')),
                  ],
                 width=800, height=height*0.5,
                 xaxis=dict(title_text='Time [ms]'),
                 yaxis=dict(title_text='E(t)',
                            range=[0,1], tickvals=[0,1], title_font_color='black'),
                 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))

fig.show()


## 8. Time spent per node degree (in-degree)

In [None]:
# Re-import so that this section can run on its own.
from Topology import Topology

In [None]:
# Recomputes the same directed in-degrees as section 0 (identical arguments).
top = Topology(wc, ext_input=[2.4, 1.12], fix_params=fix, Cmat=ds.Cmat)
top.getDirectedNodeDegreeIN()
nd = top.directedNodeDegreesIN

In [None]:
up_dur_mean = np.sum(states==1, axis=1)/states[0].size*100
down_dur_mean = np.sum(states==0, axis=1)/states[0].size*100

In [None]:
states[0].size

In [None]:
# Percentage of time each node spends in UP / DOWN versus its in-degree.
fig = go.Figure()

fig.add_trace(go.Scatter(x=nd, y=up_dur_mean, mode='markers', 
                         marker=dict(color=colors['up_duration']),
                         name='up'))
fig.add_trace(go.Scatter(x=nd, y=down_dur_mean, mode='markers', 
                         marker=dict(color=colors['down_duration']),
                         name='down'))

fig.update_layout(template=template,
                  annotations=[
                      dict(x=-0.2,y=1.1, text='(a)')
                  ],
                 width=width, height=height,
                 legend=dict(x=0.65, y=0.5),
                 xaxis=dict(title_text='Node degree'),
                 yaxis=dict(title_text='Time spent [%]', tickvals=[0,20,40,60,80]))

fig.show()

In [None]:
# Debug check of the container type (inspection output).
type(np.array(up_dur_mean))

## Investigation of Bistability Regime

In [None]:
high = wc.exc[2][-50000:]
low = wc.exc[31][-50000:]

In [None]:
fig = go.Figure()

fig.add_trace(go.Scatter(x=np.linspace(0,2,20000), y=low[-20000:],
                        mode='lines', name='Node #32',
                         showlegend=True,
                        marker=dict(color='black')))

fig.update_layout(template=template,
                  annotations=[
                      dict(x=-0.078, y=1.25, text='(b)', font=dict(color='black')),
                  ],
                 width=template.layout.width*(2/3), height=height*0.5, 
                 xaxis=dict(title_text='', tickvals=[0,1,2]),
                 yaxis=dict(title_text='E(t)', 
                            range=[0,1], tickvals=[0,1], title_font_color='black'),
                 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))

fig.show()

In [None]:
fig = go.Figure()

fig.add_trace(go.Scatter(x=np.linspace(0,5,50000), y=high,
                        mode='lines', name='Node #3',
                        showlegend=True,
                        marker=dict(color='black')))

fig.update_layout(template=template,
                  annotations=[
                      dict(x=-0.078, y=1.25, text='(c)', font=dict(color='black')),
                  ],
                 width=template.layout.width*(2/3), height=height*0.5, 
                 xaxis=dict(title_text='Time [s]'),
                 yaxis=dict(title_text='E(t)', 
                            range=[0,1], tickvals=[0,1], title_font_color='black'),
                 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))

fig.show()