# Generating stability plots and frequency-spectrum heatmaps for the local model

### Import modules and set the module search path

In [None]:
# this path append is for binder only
import sys
sys.path.append("../../")

#spectrome modules
# from spectrome.stability import GetPoles_Laplace
from spectrome.utils import functions, path
from spectrome.brain import Brain
from spectrome.stability import localstability
from spectrome.forward import ntf_local_inverselaplace
# from spectrome.forward import getmacropoles_findroot
# from spectrome.forward import getmacropoles_domainmatrix
from spectrome.forward import ntf_local_freqplot as nt


#generic modules
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import copy

# Build a Brain object from the default connectome data. Its ntf_params dict is
# printed at the end of this cell and is later passed to
# localstability.local_stability.
new_brain = Brain.Brain()

hcp_dir = path.get_data_path() # connectome information is in /data/ dir
new_brain.add_connectome(hcp_dir) # Use default files in /data/

# Some re-ordering and normalizing (reduced):
# The return values of these three calls are not used, so they act on
# new_brain's own state. Their exact effect is defined in
# spectrome.brain.Brain -- NOTE(review): confirm the details there.
new_brain.reorder_connectome(new_brain.connectome, new_brain.distance_matrix)
new_brain.bi_symmetric_c()
new_brain.reduce_extreme_dir()

# Show the model parameters that the following cells read and modify.
print(new_brain.ntf_params)

In [None]:

# Plot the roots returned by localstability.local_stability in the complex
# plane for several values of g_ei at a fixed g_ii. Roots with a positive real
# part are drawn as orange stars; all other roots are drawn as blue dots.
plt.rcParams.update({
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
    "xtick.labelsize": 12,
    "ytick.labelsize": 12
})

gii_fixed = 0.5
gei_values = [0.2, 0.5, 0.8]

fig, ax = plt.subplots(1, len(gei_values), figsize=(10, 3))

for panel, gei_value in zip(ax, gei_values):
    # NOTE: this writes into new_brain.ntf_params in place. After the loop the
    # dict holds the last (gii, gei) pair, as the unrolled version did.
    new_brain.ntf_params["gii"] = gii_fixed
    new_brain.ntf_params["gei"] = gei_value
    roots, _ = localstability.local_stability(new_brain.ntf_params)
    ind = roots.real > 0  # mask of roots in the open right half-plane
    panel.scatter(roots.real[~ind], roots.imag[~ind], c='blue')
    panel.scatter(roots.real[ind], roots.imag[ind], c='orangered', marker='*', s=80)
    panel.grid(True)
    panel.set_title(rf'$g_{{ii}}={gii_fixed}$, $g_{{ei}}={gei_value}$', fontsize=17)

# Invisible full-figure axes that exist only to carry the shared axis labels.
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor="none", bottom=False, left=False)
plt.xlabel('Real', fontsize=18)
plt.ylabel('Imaginary', labelpad=12, fontsize=18)

plt.tight_layout()


## Plot the frequency spectrum of the local model

In [None]:
def local_signal(x):
    """Return the smoothed local-model spectrum on a 1-40 Hz grid.

    Parameters
    ----------
    x : sequence
        Local-model parameters, passed unchanged to ``nt.ntf_local``. The
        sweeps in this notebook order them as [g_ei, g_ii, tau_e, tau_i].

    Returns
    -------
    numpy.ndarray
        One value per frequency in ``np.linspace(1, 40, 40)``. Each value is
        the squared magnitude of the transfer function, smoothed with a
        normalized 5-tap [1, 2, 5, 2, 1] kernel, passed through
        ``functions.mag2db``, and then halved.
    """
    fmin = 1  # 1 Hz - 40 Hz signal range
    fmax = 40
    fvec = np.linspace(fmin, fmax, fmax)  # 1 Hz spacing because fmin == 1

    # Evaluate the transfer function at every grid frequency. The loop follows
    # fvec itself, so its length cannot drift out of sync with the grid.
    htotal = np.empty(fvec.shape, dtype=complex)
    for i, freq in enumerate(fvec):
        # The first output of ntf_local is used as the transfer-function value
        # (assumed from this usage -- confirm in ntf_local_freqplot).
        htotal[i] = nt.ntf_local(x, freq)[0]

    lpf = np.array([1, 2, 5, 2, 1], dtype=float)
    lpf /= lpf.sum()  # unit-gain smoothing kernel

    spectrum = np.abs(htotal) ** 2
    # 'same' keeps the output the same length as fvec. Edge bins are smoothed
    # against implicit zero padding.
    return functions.mag2db(np.convolve(spectrum, lpf, 'same')) / 2


# Frequency grid matching the one used inside local_signal (1-40 Hz, 1 Hz steps).
fmin = 1
fmax = 40
fvec = np.linspace(fmin, fmax, fmax)

# Parameter sweep grids, chosen to stay within the model's stable regime.
# The parameter vector is ordered [g_ei, g_ii, tau_e, tau_i].
n_sweep = 50
gei = np.linspace(0.01, 0.3, n_sweep)
gii = np.linspace(1, 1.7, n_sweep)
taue = np.linspace(0.005, 0.04, n_sweep)
taui = np.linspace(0.005, 0.04, n_sweep)

# (frequency x sweep point) result buffer.
freqres = np.zeros((len(fvec), n_sweep))

# Example spectrum at a single parameter set (tau_i = 0.02 s here).
par = [0.25, 1.5, 0.01, 0.02]
sig = local_signal(par)

plt.plot(fvec, sig)
plt.xlabel('Frequency (Hz)')

## Frequency-spectrum heatmaps for single-parameter sweeps

In [None]:
def _centered_ticks(grid, tick_values):
    """Return heatmap tick positions for tick_values on an increasing grid.

    seaborn draws cell k over [k, k + 1], so grid[k] sits at k + 0.5. Values
    that fall between grid points are placed by linear interpolation.
    """
    return np.interp(tick_values, grid, np.arange(len(grid))) + 0.5


def _sweep_spectra(index, values, base):
    """Return local_signal spectra, one column per entry of values.

    Entry `index` of the parameter vector takes each value in turn; all other
    entries come from base, which is not modified. The result has shape
    (len(fvec), len(values)).
    """
    spectra = np.zeros((len(fvec), len(values)))
    par = list(base)
    for col, value in enumerate(values):
        par[index] = value
        spectra[:, col] = local_signal(par)
    return spectra


# Default parameter vector, ordered [g_ei, g_ii, tau_e, tau_i]. Each panel
# varies one entry and holds the other three at these values.
par_default = [0.25, 1.5, 0.01, 0.005]
fig, ax = plt.subplots(2, 2, figsize=(8, 6))
cbar_ax = fig.add_axes([.9, .3, .01, .4])  # single colorbar shared by all panels

freq_ticks = np.array([1, 10, 20, 30, 40])

# (target axes, index into the parameter vector, sweep grid, x label, label decimals)
panels = [
    (ax[0, 0], 0, gei, r'$g_{ei}$', 2),
    (ax[0, 1], 1, gii, r'$g_{ii}$', 2),
    (ax[1, 0], 2, taue, r'$\tau_e$ (s)', 3),
    (ax[1, 1], 3, taui, r'$\tau_i$ (s)', 3),
]

for k, (panel, index, grid, x_label, decimals) in enumerate(panels):
    freqres = _sweep_spectra(index, grid, par_default)
    with_cbar = k == 0  # only the first panel draws into the shared colorbar
    sns.heatmap(freqres, vmin=-50, vmax=-20, cmap="mako", ax=panel,
                cbar=with_cbar, cbar_ax=cbar_ax if with_cbar else None)

    # Put the labels on cell centres, not cell edges.
    panel.set_yticks(_centered_ticks(fvec, freq_ticks))
    panel.set_yticklabels(freq_ticks)

    x_ticks = np.linspace(grid[0], grid[-1], 6)
    panel.set_xticks(_centered_ticks(grid, x_ticks))
    panel.set_xticklabels(np.round(x_ticks, decimals=decimals))
    panel.set_xlabel(x_label, fontsize=17)
    panel.invert_yaxis()  # low frequencies at the bottom

# Invisible full-figure axes that exist only to carry the shared y label.
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor="none", bottom=False, left=False)
plt.ylabel('Frequency (Hz)', labelpad=12, fontsize=17)

fig.tight_layout(rect=[0, 0, .9, 1])

## Primary peak frequency as a function of $\tau_e$ and $\tau_i$

In [None]:
# Default parameter vector, ordered [g_ei, g_ii, tau_e, tau_i].
par_default = [0.25, 1.5, 0.01, 0.005]

# Frequency (Hz, taken from fvec) of the spectrum's maximum for every
# (tau_e, tau_i) pair. Rows follow taue and columns follow taui. np.argmax
# picks the first maximum if several bins tie.
freqpeak = np.zeros((len(taue), len(taui)))
for i, tau_e in enumerate(taue):
    for j, tau_i in enumerate(taui):
        par = list(par_default)
        par[2] = tau_e
        par[3] = tau_i
        freqpeak[i, j] = fvec[np.argmax(local_signal(par))]

ax = sns.heatmap(freqpeak, cmap="mako", cbar_kws={"label": "Peak frequency (Hz)"})

# seaborn centres cell k at k + 0.5. Place each tau tick on the (interpolated)
# cell that holds that value, using the same labelling on both axes.
taui_ticks = np.linspace(taui[0], taui[-1], 6)
taue_ticks = np.linspace(taue[0], taue[-1], 6)
plt.xticks(np.interp(taui_ticks, taui, np.arange(len(taui))) + 0.5,
           np.round(taui_ticks, decimals=3))
plt.yticks(np.interp(taue_ticks, taue, np.arange(len(taue))) + 0.5,
           np.round(taue_ticks, decimals=3))
plt.xlabel(r'$\tau_i$ (s)', fontsize=17)
plt.ylabel(r'$\tau_e$ (s)', fontsize=17)
ax.invert_yaxis()  # small tau_e at the bottom