# This notebook generates plots of the dynamics of the model parameters, as shown in sections 2.4 and 2.5 of the paper.

In [None]:
import sys
sys.path.append("../../")

from spectrome.utils import functions, path
import numpy as np
import xarray as xr
from pathlib import Path
import os
import pickle as pkl

from spectrome.brain import Brain
from spectrome.forward import runforward

import matplotlib.pyplot as plt

# Use STIX fonts for both regular text and mathtext so LaTeX-style axis labels
# match the surrounding text, and enlarge tick labels on every figure.
plt.rcParams["mathtext.fontset"] = "stix"
plt.rcParams["font.family"] = "STIXGeneral"
plt.rcParams["xtick.labelsize"] = 12
plt.rcParams["ytick.labelsize"] = 12

In [None]:
# Shared plot settings.
fz = 16  # axis-label font size
c = '#0571b0'  # hex colour constant (later cells reuse the name `c` for plot handles)

# Time axis of the 12 fitted windows: a 0-60 grid of 36000 samples, taken every
# 600 samples, then every 5th of those sub-sampled points.
t = np.linspace(0, 60, 36000)
tpoints = np.arange(60, dtype=int) * 600
trange = np.arange(0, 60, 5)
timepoints = np.take(t, tpoints)
tpoints2 = np.take(timepoints, trange)
print(tpoints2)

# define data directory
data_dir = path.get_data_path()

# cortical areas with MEG collected + source reconstructed
rois_with_MEG = np.arange(68)

# Load MEG (note: this overwrites the synthetic `timepoints` above with the file's coordinate).
ind_psd_xr = xr.open_dataarray(data_dir + '/individual_wavelet_reordered_smooth.nc')
fvec, timepoints = (ind_psd_xr[coord].values for coord in ("frequencies", "timepoints"))
ind_psd = ind_psd_xr.values

# Window indices 0..11; paramlist2 mirrors `trange` (every 5th sub-sampled point).
paramlist = range(12)
paramlist2 = np.arange(0, 60, 5)

# Static (time-invariant) per-subject parameter fits, one row per subject.
ind_param_static = np.loadtxt("../results/MSGM_Reordered_matlab_500iter_unstablegii.csv", delimiter=",")

### Generate distributions for all subjects

In [None]:
# Load the per-subject time-varying fits of the bounded-alpha ("stablealpha") run.
# Each pickle holds a per-window array; columns 0-3 are plotted below as
# alpha, g_ei, g_ii and Pearson's r (12 windows per subject).
subsdir = Path("../results/stablealpha")

# Result files are named "<subject id>.p...". Ignore anything else in the folder
# (e.g. .DS_Store) and sort numerically so row i of every array is subject i.
sub_arr = [f.name for f in subsdir.iterdir() if f.name.split('.p')[0].isdigit()]
sub_arr_sorted = np.array(sorted(sub_arr, key=lambda name: int(name.split('.p')[0])))
assert len(sub_arr_sorted) == 36, f"expected 36 subject files in {subsdir}, found {len(sub_arr_sorted)}"

alphal = np.empty((36,12))
geil = np.empty((36,12))
giil = np.empty((36,12))
pearsonrl = np.empty((36,12))

fig, ax = plt.subplots(1,4, figsize=(15,4))

for i, fname in enumerate(sub_arr_sorted):
    # NOTE: pickle.load can execute arbitrary code; only load result files produced by this project.
    with open(subsdir/fname, 'rb') as f:
        x = np.array(pkl.load(f))
    # Column k of the fit goes to panel k and to the k-th summary array.
    for col, store in enumerate((alphal, geil, giil, pearsonrl)):
        ax[col].plot(tpoints2, x[:, col])
        store[i, :] = x[:, col]

for axis, label in zip(ax, (r'$\alpha$', r'$g_{ei}$', r'$g_{ii}$', r'Pearsons r')):
    axis.set_ylabel(label, fontsize=fz)
    axis.set_xlabel('Time (s)', fontsize=fz)

plt.tight_layout()


In [None]:
fig, ax = plt.subplots(2, 2, figsize=(10, 8))

cmap = 'plasma'

# One heatmap per fitted quantity: rows = subjects (subject 0 at the top), columns = time windows.
for axis, values, label in ((ax[0][0], alphal, r'$\alpha$'),
                            (ax[0][1], geil, r'$g_{ei}$'),
                            (ax[1][0], giil, r'$g_{ii}$'),
                            (ax[1][1], pearsonrl, r'Pearsons r')):
    c = axis.pcolormesh(tpoints2, range(36), values, cmap=cmap, shading='auto')
    cbar = fig.colorbar(c, ax=axis)
    cbar.ax.set_ylabel(label, fontsize=16)
    axis.invert_yaxis()

# Invisible full-figure axes that only carries the shared axis labels.
overlay = fig.add_subplot(111, frameon=False)
overlay.tick_params(labelcolor="none", bottom=False, left=False)
overlay.set_xlabel('Time (s)', fontsize=17)
overlay.set_ylabel('Subject ID', fontsize=17)
fig.tight_layout()


In [None]:
# Load the per-subject time-varying fits of the unbounded-alpha ("unstablealpha") run.
# Same file layout as the bounded run: columns 0-3 are plotted below as
# alpha, g_ei, g_ii and Pearson's r (12 windows per subject).
subsdir2 = Path("../results/unstablealpha")

# Result files are named "<subject id>.p...". Ignore anything else in the folder
# (e.g. .DS_Store) and sort numerically so row i of every array is subject i.
sub_arr2 = [f.name for f in subsdir2.iterdir() if f.name.split('.p')[0].isdigit()]
sub_arr_sorted2 = np.array(sorted(sub_arr2, key=lambda name: int(name.split('.p')[0])))
assert len(sub_arr_sorted2) == 36, f"expected 36 subject files in {subsdir2}, found {len(sub_arr_sorted2)}"

alphal2 = np.empty((36,12))
geil2 = np.empty((36,12))
giil2 = np.empty((36,12))
pearsonrl2 = np.empty((36,12))

fig, ax = plt.subplots(1,4, figsize=(15,4))

for i, fname in enumerate(sub_arr_sorted2):
    # NOTE: pickle.load can execute arbitrary code; only load result files produced by this project.
    with open(subsdir2/fname, 'rb') as f:
        x = np.array(pkl.load(f))
    # Column k of the fit goes to panel k and to the k-th summary array.
    for col, store in enumerate((alphal2, geil2, giil2, pearsonrl2)):
        ax[col].plot(tpoints2, x[:, col])
        store[i, :] = x[:, col]

for axis, label in zip(ax, (r'$\alpha$', r'$g_{ei}$', r'$g_{ii}$', r'Pearsons r')):
    axis.set_ylabel(label, fontsize=fz)
    axis.set_xlabel('Time (s)', fontsize=fz)

plt.tight_layout()

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(10, 8))

cmap = 'plasma'

# One heatmap per fitted quantity (unbounded-alpha run): rows = subjects, columns = time windows.
for axis, values, label in ((ax[0][0], alphal2, r'$\alpha$'),
                            (ax[0][1], geil2, r'$g_{ei}$'),
                            (ax[1][0], giil2, r'$g_{ii}$'),
                            (ax[1][1], pearsonrl2, r'Pearsons r')):
    c = axis.pcolormesh(tpoints2, range(36), values, cmap=cmap, shading='auto')
    cbar = fig.colorbar(c, ax=axis)
    cbar.ax.set_ylabel(label, fontsize=16)
    axis.invert_yaxis()

# Invisible full-figure axes that only carries the shared axis labels.
overlay = fig.add_subplot(111, frameon=False)
overlay.tick_params(labelcolor="none", bottom=False, left=False)
overlay.set_xlabel('Time (s)', fontsize=17)
overlay.set_ylabel('Subject ID', fontsize=17)
fig.tight_layout()

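## Count the number of switches in $\alpha$
A switch is an increase in $\alpha$ of at least 0.5 between consecutive time windows; decreases are not counted.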

In [None]:
# A "switch" is counted whenever alpha increases by at least `thr` between two
# consecutive time windows (decreases are not counted). `alphal` already holds the
# alpha trajectories loaded from the same pickle files in the same subject order,
# so the files are not re-read here.
thr = 0.5
n_switches = np.count_nonzero(np.diff(alphal, axis=1) >= thr, axis=1).astype(float).reshape(-1, 1)

fig, ax = plt.subplots()
ax.scatter(range(36), n_switches)
# Always show ticks 0-3, extending them if any subject switches more often.
yticks = np.arange(max(3, int(n_switches.max())) + 1)
ax.set_yticks(yticks)
ax.set_yticklabels(yticks)
ax.set_xlabel(r'Subject ID', fontsize=17)
ax.set_ylabel(r'Number of switches in $\alpha$', fontsize=17)
fig.tight_layout()

# Number of subjects with at least one switch.
print(np.count_nonzero(n_switches > 0))

In [None]:
# Same switch count for the unbounded-alpha run: an increase of at least `thr`
# between consecutive windows (decreases are not counted). `alphal2` already holds
# these trajectories, so the pickle files are not re-read here.
thr = 0.5
n_switches = np.count_nonzero(np.diff(alphal2, axis=1) >= thr, axis=1).astype(float).reshape(-1, 1)

fig, ax = plt.subplots()
ax.scatter(range(36), n_switches)
# Always show ticks 0-3, extending them if any subject switches more often.
yticks = np.arange(max(3, int(n_switches.max())) + 1)
ax.set_yticks(yticks)
ax.set_yticklabels(yticks)
ax.set_xlabel(r'Subject ID', fontsize=17)
ax.set_ylabel(r'Number of switches in $\alpha$', fontsize=17)
fig.tight_layout()

# Number of subjects with at least one switch.
print(np.count_nonzero(n_switches > 0))

## Get local stability of all subjects

In [None]:
from spectrome.stability import localstability

brain = Brain.Brain()

def check_localstability(brain, subid):
    """Classify local and macroscopic stability per time window for one subject
    of the bounded-alpha run (files in ``subsdir``).

    For every window index in ``paramlist``, the subject's static parameters from
    ``ind_param_static`` and that window's fitted alpha, g_ei and g_ii are written
    into ``brain.ntf_params``, which is then passed to
    ``localstability.local_stability``.

    Parameters
    ----------
    brain : Brain
        Model instance; its ``ntf_params`` dict is overwritten in place.
    subid : int
        Row index of the subject in ``sub_arr_sorted`` and ``ind_param_static``.

    Returns
    -------
    local_st : list
        Code returned by ``local_stability`` per window (2 is treated as unstable).
    macro_st : list of int
        2 where alpha == 1, else 0.
    total_st : list of int
        1 = both stable, 2 = alpha == 1 only, 3 = locally unstable only, 4 = both.

    Raises
    ------
    ValueError
        If ``local_stability`` returns a code that is neither 2 nor < 2 (e.g. NaN).
    """
    # NOTE: pickle.load can execute arbitrary code; only load result files produced by this project.
    with open(subsdir/sub_arr_sorted[subid], 'rb') as f:
        x = np.array(pkl.load(f))

    local_st = []
    macro_st = []
    total_st = []
    for i in paramlist:
        # Static parameters of this subject; time constants divided by 1000 (presumably ms -> s).
        brain.ntf_params["tau_e"] = ind_param_static[subid,0]/1000
        brain.ntf_params["tau_i"] = ind_param_static[subid,1]/1000
        brain.ntf_params["speed"] = ind_param_static[subid,3]
        brain.ntf_params["tauC"] = ind_param_static[subid,6]/1000
        # Time-varying parameters fitted for this window.
        brain.ntf_params["alpha"] = x[i,0]
        brain.ntf_params["gei"] = x[i,1]
        brain.ntf_params["gii"] = x[i,2]
        _, st = localstability.local_stability(brain.ntf_params)

        # alpha == 1 is flagged as macroscopically critical (presumably the upper bound of this run).
        macro = 2 if x[i,0] == 1 else 0

        # Combine both flags using this window's `macro` directly (indexing macro_st[i]
        # only worked because paramlist happens to start at 0).
        if st == 2:
            total = 4 if macro == 2 else 3
        elif st < 2:
            total = 2 if macro == 2 else 1
        else:
            # Previously such windows were silently skipped, leaving total_st shorter
            # than the other lists and breaking the (36, 12) arrays later on.
            raise ValueError(f"unexpected local stability code {st!r} for subject {subid}, window {i}")

        local_st.append(st)
        macro_st.append(macro)
        total_st.append(total)
    return local_st, macro_st, total_st

In [None]:
# Stability codes per subject (rows) and time window (columns) for the bounded-alpha run.
localstl, macrostl, totalstl = (np.empty((36, 12)) for _ in range(3))

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
for i in range(36):
    localst, macrost, totalst = check_localstability(brain, i)
    for axis, series in zip(ax, (localst, macrost, totalst)):
        axis.plot(tpoints2, series)
    localstl[i, :], macrostl[i, :], totalstl[i, :] = localst, macrost, totalst

# (tick positions, tick labels, title) for each of the three panels.
panel_specs = (
    ([0, 2], ['stable', 'unstable'], 'Local stability'),
    ([0, 2], ['stable', 'alpha=1'], 'Macro stability'),
    ([1, 2, 3, 4], ['both stable', 'alpha=1', 'local unstable', 'both unstable'], 'Total stability'),
)
for axis, (ticks, labels, title) in zip(ax, panel_specs):
    axis.set_yticks(ticks)
    axis.set_yticklabels(labels)
    axis.set_title(title)

plt.tight_layout()

In [None]:
fig, ax = plt.subplots(1,3, figsize=(15,5))

# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap(name, lut) still returns the colormap resampled to `lut` colours.
cmap1 = plt.get_cmap("plasma", 2)
cmap2 = plt.get_cmap("plasma", 4)

# Colour limits are fixed to the full code range so each discrete colour always maps
# to the same code. With auto-scaling, a code missing from the data would shift the
# bins and the colorbar labels would no longer match. Ticks sit at bin centres:
# two bins over [0, 2], four bins of width 0.75 over [1, 4].
# Total-stability classes: I = both stable, II = alpha=1, III = local unstable, IV = both unstable.
panels = (
    (ax[0], localstl, cmap1, 0, 2, [0.5, 1.5], ['Local stable', 'Local unstable']),
    (ax[1], macrostl, cmap1, 0, 2, [0.5, 1.5], [r'$\alpha<1$', r'$\alpha=1$']),
    (ax[2], totalstl, cmap2, 1, 4, [1.375, 2.125, 2.875, 3.625], [r'I', r'II', r'III', r'IV']),
)
for axis, codes, panel_cmap, vmin, vmax, ticks, labels in panels:
    c = axis.pcolormesh(tpoints2, range(36), codes, cmap=panel_cmap, vmin=vmin, vmax=vmax, shading='auto')
    cbar = fig.colorbar(c, ax=axis, ticks=ticks)
    cbar.ax.set_yticklabels(labels, rotation=90, fontsize=15, va='center')
    axis.invert_yaxis()

# Invisible full-figure axes that only carries the shared axis labels.
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor="none", bottom=False, left=False)
plt.xlabel('Time (s)',fontsize=17)
plt.ylabel('Subject ID',fontsize=17)
plt.tight_layout()


In [None]:
def check_localstability(brain, subid):
    """Classify local and macroscopic stability per time window for one subject
    of the unbounded-alpha run (files in ``subsdir2``).

    NOTE(review): this shadows the bounded-alpha ``check_localstability`` defined in
    an earlier cell; every later call uses this version. Consider one function
    parameterized by the results directory instead.

    For every window index in ``paramlist``, the subject's static parameters from
    ``ind_param_static`` and that window's fitted alpha, g_ei and g_ii are written
    into ``brain.ntf_params``, which is then passed to
    ``localstability.local_stability``.

    Parameters
    ----------
    brain : Brain
        Model instance; its ``ntf_params`` dict is overwritten in place.
    subid : int
        Row index of the subject in ``sub_arr_sorted2`` and ``ind_param_static``.

    Returns
    -------
    local_st : list
        Code returned by ``local_stability`` per window (2 is treated as unstable).
    macro_st : list of int
        2 where alpha >= 1, else 0.
    total_st : list of int
        1 = both stable, 2 = alpha >= 1 only, 3 = locally unstable only, 4 = both.

    Raises
    ------
    ValueError
        If ``local_stability`` returns a code that is neither 2 nor < 2 (e.g. NaN).
    """
    # NOTE: pickle.load can execute arbitrary code; only load result files produced by this project.
    with open(subsdir2/sub_arr_sorted2[subid], 'rb') as f:
        x = np.array(pkl.load(f))

    local_st = []
    macro_st = []
    total_st = []
    for i in paramlist:
        # Static parameters of this subject; time constants divided by 1000 (presumably ms -> s).
        brain.ntf_params["tau_e"] = ind_param_static[subid,0]/1000
        brain.ntf_params["tau_i"] = ind_param_static[subid,1]/1000
        brain.ntf_params["speed"] = ind_param_static[subid,3]
        brain.ntf_params["tauC"] = ind_param_static[subid,6]/1000
        # Time-varying parameters fitted for this window.
        brain.ntf_params["alpha"] = x[i,0]
        brain.ntf_params["gei"] = x[i,1]
        brain.ntf_params["gii"] = x[i,2]
        _, st = localstability.local_stability(brain.ntf_params)

        # alpha >= 1 is flagged as macroscopically unstable.
        macro = 2 if x[i,0] >= 1 else 0

        # Combine both flags using this window's `macro` directly (indexing macro_st[i]
        # only worked because paramlist happens to start at 0).
        if st == 2:
            total = 4 if macro == 2 else 3
        elif st < 2:
            total = 2 if macro == 2 else 1
        else:
            # Previously such windows were silently skipped, leaving total_st shorter
            # than the other lists and breaking the (36, 12) arrays later on.
            raise ValueError(f"unexpected local stability code {st!r} for subject {subid}, window {i}")

        local_st.append(st)
        macro_st.append(macro)
        total_st.append(total)
    return local_st, macro_st, total_st


fig, ax = plt.subplots(1,3, figsize=(12,4))
for i in range(36):
    localst, macrost, totalst = check_localstability(brain,i)
    ax[0].plot(tpoints2,localst)
    ax[1].plot(tpoints2,macrost)
    ax[2].plot(tpoints2,totalst)

ax[0].set_yticks([0,2])
ax[0].set_yticklabels(['stable','unstable'])
ax[0].set_title('Local stability')

ax[1].set_yticks([0,2])
ax[1].set_yticklabels(['stable','alpha>=1'])
ax[1].set_title('Macro stability')

ax[2].set_yticks([1,2,3,4])
ax[2].set_yticklabels(['both stable', 'alpha>=1', 'local unstable', 'both unstable'])
ax[2].set_title('Total stability')

plt.tight_layout()

In [None]:
# Build the stability maps for the unbounded-alpha run. The previous cell only drew
# line plots, so localstl/macrostl/totalstl still hold the *bounded*-alpha results;
# plotting them here would mislabel that data. check_localstability is the
# unbounded-alpha version defined in the previous cell (it reads subsdir2).
localstl2 = np.empty((36,12))
macrostl2 = np.empty((36,12))
totalstl2 = np.empty((36,12))
for i in range(36):
    localstl2[i, :], macrostl2[i, :], totalstl2[i, :] = check_localstability(brain, i)

fig, ax = plt.subplots(1,3, figsize=(15,5))

# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap(name, lut) still returns the colormap resampled to `lut` colours.
cmap1 = plt.get_cmap("plasma", 2)
cmap2 = plt.get_cmap("plasma", 4)

# Colour limits are fixed to the full code range so each discrete colour always maps
# to the same code, whichever codes occur. Ticks sit at bin centres: two bins over
# [0, 2], four bins of width 0.75 over [1, 4].
# Total-stability classes: I = both stable, II = alpha>=1, III = local unstable, IV = both unstable.
panels = (
    (ax[0], localstl2, cmap1, 0, 2, [0.5, 1.5], ['Local stable', 'Local unstable']),
    (ax[1], macrostl2, cmap1, 0, 2, [0.5, 1.5], [r'$\alpha<1$', r'$\alpha \geq 1$']),
    (ax[2], totalstl2, cmap2, 1, 4, [1.375, 2.125, 2.875, 3.625], [r'I', r'II', r'III', r'IV']),
)
for axis, codes, panel_cmap, vmin, vmax, ticks, labels in panels:
    c = axis.pcolormesh(tpoints2, range(36), codes, cmap=panel_cmap, vmin=vmin, vmax=vmax, shading='auto')
    cbar = fig.colorbar(c, ax=axis, ticks=ticks)
    cbar.ax.set_yticklabels(labels, rotation=90, fontsize=15, va='center')
    axis.invert_yaxis()

# Invisible full-figure axes that only carries the shared axis labels.
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor="none", bottom=False, left=False)
plt.xlabel('Time (s)',fontsize=17)
plt.ylabel('Subject ID',fontsize=17)
plt.tight_layout()
