# E-PV-SST-NDNF perturbation flow map

## Symbolic Calculations

In [54]:
import sympy as sym

# Symbolic derivation of the steady-state response matrix used in
# ExtendedCircuit.calculate_linear.  Weight and time-constant symbols are named
# "self.<name>" so the printed expressions can be pasted directly into the class.
(tauEs, tauPs, tauSs, tauIs,
 bEs, bPs, bSs, bIs,
 wEEs, wEPs, wESs, wEIs,
 wPEs, wPPs, wPSs, wPIs,
 wSEs, wSPs, wSSs, wSIs,
 wIEs, wIPs, wISs, wIIs,
 wEI_correction_s) = sym.symbols(
    'self.tauE self.tauP self.tauS self.tauI '
    'bE bP bS bI '
    'self.wEE self.wEP self.wES self.wEI '
    'self.wPE self.wPP self.wPS self.wPI '
    'self.wSE self.wSP self.wSS self.wSI '
    'self.wIE self.wIP self.wIS self.wII '
    'wEI_correction'
)

# Diagonal gain (B) and time-constant (T) matrices, their inverses, and identity.
Bs = sym.diag(bEs, bPs, bSs, bIs)
invBs = sym.diag(1/bEs, 1/bPs, 1/bSs, 1/bIs)
Ts = sym.diag(tauEs, tauPs, tauSs, tauIs)
invTs = sym.diag(1/tauEs, 1/tauPs, 1/tauSs, 1/tauIs)
Es = sym.eye(4)

# Signed weight matrix (inhibitory columns negated), populations ordered
# E, PV, SST, interneuron.  Only E<-{E,P,S,I}, P<-{E,P,S}, S<-E and I<-S are
# kept here; put -wPIs, -wSPs, -wSSs, -wSIs, wIEs, -wIPs, -wIIs back into the
# zero entries to derive the fully connected case.
Ws = sym.Matrix([
    [wEEs, -wEPs, -wESs, -wEIs],
    [wPEs, -wPPs, -wPSs, 0],
    [wSEs, 0, 0, 0],
    [0, 0, -wISs, 0],
])

# Rate-dependence correction of the E<-I entry (dW_ES/drI * rS for NDNF).
Tr_s = sym.zeros(4, 4)
Tr_s[0, 3] = wEI_correction_s

# Jacobian J = T^-1 (-I + B (W + Tr)).  The matrix -B^-1 T J equals
# B^-1 - (W + Tr); its inverse adj/det is the response matrix L.
Js = invTs * (-Es + Bs * (Ws + Tr_s))
resp_inv_s = -invBs * Ts * Js
dets = sym.simplify(resp_inv_s.det())
adjs = resp_inv_s.adjugate()

adj_EEs, adj_EPs, adj_EIs, adj_PIs, adj_SIs, adj_IIs = (
    sym.simplify(adjs[row, col])
    for row, col in ((0, 0), (0, 1), (0, 3), (1, 3), (2, 3), (3, 3))
)

# Only L_EI is printed; format adj_EEs, adj_EPs, adj_PIs, adj_SIs, adj_IIs or
# dets the same way when those entries are needed.
print(f'L_EI = ({adj_EIs}) / det_num\n')

L_EI = ((-bP*self.wEI*self.wPP + bP*self.wPP*wEI_correction - self.wEI + wEI_correction)/(bP*bS)) / det_num



In [56]:
# Substitute the NDNF rate dependence into the symbolic L_EI entry.
# wES = wES_0 * exp(-rI/alfa) and W_13 = -wES, so
# pESN = dW_13/drI = (wES_0/alfa) * exp(-rI/alfa); the E<-I correction is pESN * rS.
wES_0, alfa, rI, rS, bP, bS, wEI, wPP, det_num = sym.symbols(
    'wES_0 alfa rI rS bP bS wEI wPP det_num'
)
pESN = wES_0 / alfa * sym.exp(-rI / alfa)
wEI_correction = rS * pESN

L_EI = sym.simplify(
    (wEI_correction - wEI + bP*wPP*wEI_correction - bP*wEI*wPP) / (bP*bS) / det_num
)

print(f'L_EI = {L_EI}')

L_EI = (-alfa*wEI*(bP*wPP + 1)*exp(rI/alfa) + bP*rS*wES_0*wPP + rS*wES_0)*exp(-rI/alfa)/(alfa*bP*bS*det_num)


## Working Code

In [1]:
import numpy as np
from matplotlib.figure import Figure
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from pathlib import Path
from scipy.linalg import eig
from typing import Tuple
import warnings
warnings.filterwarnings('ignore')

class ExtendedCircuit:
    def __init__(
            self,
            interneuron_name,
            W,
            tau=(10, 10, 10, 10),
            rS0=2,
            rI0=2,
            step_rX=0.01,
            I_stim_E=0.3,
            I_stim_P=0.3,
            I_mod_I=0.3,
            power=2,
            mult_f=1/4
    ):
        """Store the parameters of the E-PV-SST-<interneuron> rate circuit.

        Args:
            interneuron_name: Name of the fourth population; the methods of this
                class handle 'ndnf' (wES depends on its rate) and 'vip'.
            W: 4x4 weight magnitudes with rows/columns ordered E, PV, SST,
                interneuron (e.g. W[0][2] is wES).  Any array-like accepted by
                ``np.asarray``.  The signs of inhibitory connections are applied
                by the methods, not stored here.
            tau: Time constants (tauE, tauP, tauS, tauI).
            rS0: SST rate used as the operating point.
            rI0: Interneuron rate used as the operating point.
            step_rX: Grid step of the rE / rP sweeps.
            I_stim_E: Stimulus amplitude to E used when computing the gain.
            I_stim_P: Stimulus amplitude to PV used when computing the gain.
            I_mod_I: Modulatory input amplitude delivered to the interneuron.
            power: Exponent of the I/O function r = mult_f * u**power.
            mult_f: Prefactor of the I/O function r = mult_f * u**power.
        """
        # Define parameters
        self.interneuron_name = interneuron_name
        self.W = W

        (self.wEE, self.wEP, self.wES_0, self.wEI,
         self.wPE, self.wPP, self.wPS, self.wPI,
         self.wSE, self.wSP, self.wSS, self.wSI,
         self.wIE, self.wIP, self.wIS, self.wII) = np.asarray(W).flatten()
        # wES is recomputed from wES_0 for 'ndnf'; start from the baseline value.
        self.wES = self.wES_0

        # A tuple default avoids sharing one mutable array between all instances.
        self.tauE, self.tauP, self.tauS, self.tauI = tau

        self.rS0 = rS0  # SST initial firing rate
        self.rI0 = rI0  # Interneuron initial firing rate
        self.step_rX = step_rX

        self.I_stim_E = I_stim_E
        self.I_stim_P = I_stim_P
        self.I_mod_I = I_mod_I

        # I/O function parameters
        self.power = power
        self.mult_f = mult_f

        # White-to-dark-red colour ramp used by the heatmaps.
        self.hex_colors = ['#ffffff', '#fcbba1', '#fc9272', '#fb6a4a', '#ef3b2c', '#cb181d', '#99000d']
        self.colors = [[int(hex_color[k:k + 2], 16) / 255 for k in (1, 3, 5)] for hex_color in self.hex_colors]
        self.cmap = LinearSegmentedColormap.from_list('custom', self.colors, N=128)

##################

    def calculate_linear(
            self,
            max_rEP,
            alfa
    ):
        """Linearize the circuit over an (rE, rP) grid before and after interneuron modulation.

        For every grid point the SST and interneuron rates are held at ``self.rS0`` /
        ``self.rI0``.  Feed-forward inputs are chosen so that the point is a fixed
        point of r = f(W r + x) with f(u) = mult_f * u**power.  From the resulting
        gains b the method builds

        * the response matrix  L = (diag(1/b) - W~)^-1, and
        * the Jacobian         J = T^-1 (-I + diag(b) W~),

        where W~ is the signed weight matrix plus, for 'ndnf', the correction
        dW_ES/drI * rS on the E<-I entry.  The linear rate shift L[:, I] * I_mod_I
        caused by the interneuron input is then applied.  The same quantities are
        recomputed at the shifted operating point with I_mod_I added to the
        interneuron input.

        Side effects: sets ``self.rE_vec`` (descending) and ``self.rP_vec``
        (ascending); for 'ndnf' overwrites ``self.wES`` with its value at the last
        evaluated operating point.

        Args:
            max_rEP: Upper bound of both rate grids (lower bound is 0.5).
            alfa: Decay constant of wES = wES_0 * exp(-rI / alfa) (used for 'ndnf').

        Returns:
            Eleven 2-D arrays indexed [i_rE, j_rP]: gain, gain after modulation,
            max Re(eigenvalue), max Re(eigenvalue) after modulation, max
            Im(eigenvalue), oscillation metric, oscillation metric after
            modulation, and the modulation-induced changes of rE, rP, rS, rI.
            Entries that could not be computed are NaN.

        Raises:
            ValueError: If ``self.interneuron_name`` is neither 'ndnf' nor 'vip'.
        """
        if self.interneuron_name not in ('ndnf', 'vip'):
            raise ValueError('Invalid interneuron name. Available options: "vip" or "ndnf".')

        self.rE_vec = np.round(np.arange(max_rEP, 0.5 - self.step_rX, -self.step_rX), 2)
        self.rP_vec = np.round(np.arange(0.5, max_rEP + self.step_rX, self.step_rX), 2)

        grid_shape = (len(self.rE_vec), len(self.rP_vec))
        f_gain_num = np.zeros(grid_shape)
        f_gain_num_mod = np.zeros(grid_shape)
        f_maxEVs_num = np.zeros(grid_shape)
        f_maxEVs_num_mod = np.zeros(grid_shape)
        f_maxImEVs_num = np.zeros(grid_shape)
        f_oscMetric_num = np.zeros(grid_shape)
        f_oscMetric_num_mod = np.zeros(grid_shape)
        f_modI_rE_num = np.zeros(grid_shape)
        f_modI_rP_num = np.zeros(grid_shape)
        f_modI_rS_num = np.zeros(grid_shape)
        f_modI_rI_num = np.zeros(grid_shape)

        tau = np.array([self.tauE, self.tauP, self.tauS, self.tauI], dtype=float)
        # The modulatory input reaches the interneuron population only.
        mod_input = np.array([0.0, 0.0, 0.0, self.I_mod_I])

        for i, rE in enumerate(self.rE_vec):
            for j, rP in enumerate(self.rP_vec):
                rS = self.rS0
                rI = self.rI0
                rates = np.array([rE, rP, rS, rI], dtype=float)

    ############# Before interneuron modulation
                wEI_correction = self._update_wES_correction(rS, rI, alfa)
                W_signed = self._signed_weight_matrix()
                recurrent = W_signed @ rates
                # Feed-forward inputs that make `rates` a fixed point of r = f(W r + x).
                x = (rates / self.mult_f)**(1 / self.power) - recurrent
                gains = self._io_gain(recurrent + x)
                L, J_num = self._response_and_jacobian(W_signed, wEI_correction, gains, tau)

                if L is None:
                    print('Determinant is zero', end='\r')
                    # The modulated point is never evaluated, so mark it undefined as well.
                    for arr in (f_gain_num, f_maxEVs_num, f_maxImEVs_num, f_oscMetric_num,
                                f_modI_rE_num, f_modI_rP_num, f_modI_rS_num, f_modI_rI_num,
                                f_gain_num_mod, f_maxEVs_num_mod, f_oscMetric_num_mod):
                        arr[i, j] = np.nan
                    continue

                f_gain_num[i, j] = L[0, 0] * self.I_stim_E + L[0, 1] * self.I_stim_P

                try:
                    f_maxEVs_num[i, j], f_maxImEVs_num[i, j], f_oscMetric_num[i, j] = self._eigen_summary(J_num)
                except (np.linalg.LinAlgError, ValueError):
                    f_maxEVs_num[i, j] = np.nan
                    f_maxImEVs_num[i, j] = np.nan
                    f_oscMetric_num[i, j] = np.nan
                    print(f"Jacobian failed at (rP = {round(rP,2)}, rE = {round(rE,2)}), input: {J_num}")

                # Linear rate changes caused by the interneuron input: column I of L.
                delta_rates = L[:, 3] * self.I_mod_I
                f_modI_rE_num[i, j], f_modI_rP_num[i, j], f_modI_rS_num[i, j], f_modI_rI_num[i, j] = delta_rates

    ############# After interneuron modulation
                rates_mod = rates + delta_rates
                # Re-evaluate wES and its correction at the shifted operating point,
                # i.e. with both rS_mod and rI_mod.
                wEI_correction_mod = self._update_wES_correction(rates_mod[2], rates_mod[3], alfa)
                W_signed_mod = self._signed_weight_matrix()
                gains_mod = self._io_gain(W_signed_mod @ rates_mod + x + mod_input)
                L_mod, J_num_mod = self._response_and_jacobian(W_signed_mod, wEI_correction_mod, gains_mod, tau)

                if L_mod is None:
                    f_gain_num_mod[i, j] = np.nan
                    f_maxEVs_num_mod[i, j] = np.nan
                    f_oscMetric_num_mod[i, j] = np.nan
                    continue

                f_gain_num_mod[i, j] = L_mod[0, 0] * self.I_stim_E + L_mod[0, 1] * self.I_stim_P

                try:
                    f_maxEVs_num_mod[i, j], _, f_oscMetric_num_mod[i, j] = self._eigen_summary(J_num_mod)
                except (np.linalg.LinAlgError, ValueError):
                    f_maxEVs_num_mod[i, j] = np.nan
                    f_oscMetric_num_mod[i, j] = np.nan
                    print(f"Jacobian mod failed at (rP = {round(rP,2)}, rE = {round(rE,2)}), input: {J_num_mod}")

        return f_gain_num, f_gain_num_mod, f_maxEVs_num, f_maxEVs_num_mod, f_maxImEVs_num, f_oscMetric_num, f_oscMetric_num_mod, f_modI_rE_num, f_modI_rP_num, f_modI_rS_num, f_modI_rI_num

    def _update_wES_correction(self, rS, rI, alfa):
        """For 'ndnf', set self.wES at interneuron rate rI and return the E<-I correction.

        wES = wES_0 * exp(-rI / alfa), so W_ES = -wES changes with rI at rate
        pESN = (wES_0 / alfa) * exp(-rI / alfa); linearizing W_ES * rS adds
        pESN * rS to the E<-I entry.  The correction is capped at wEI so that the
        corrected entry -wEI + correction never becomes positive.  Returns 0 and
        leaves self.wES unchanged for any other interneuron.
        """
        if self.interneuron_name != 'ndnf':
            return 0
        self.wES = self.wES_0 * np.exp(-rI / alfa)
        pESN = (self.wES_0 / alfa) * np.exp(-rI / alfa)  # pESN = dW_13/drI, where W_13 = -wES
        correction = pESN * rS
        if correction > self.wEI:
            correction = self.wEI
        return correction

    def _signed_weight_matrix(self):
        """Return the weight matrix (order E, PV, SST, I) with inhibitory columns negated, using the current wES."""
        return np.array([
            [self.wEE, -self.wEP, -self.wES, -self.wEI],
            [self.wPE, -self.wPP, -self.wPS, -self.wPI],
            [self.wSE, -self.wSP, -self.wSS, -self.wSI],
            [self.wIE, -self.wIP, -self.wIS, -self.wII],
        ], dtype=float)

    def _io_gain(self, total_input):
        """Slope of the I/O function f(u) = mult_f * u**power at input u (element-wise)."""
        return self.power * self.mult_f * total_input**(self.power - 1)

    @staticmethod
    def _response_and_jacobian(W_signed, wEI_correction, gains, tau):
        """Return (L, J) for the given signed weights, E<-I correction, gains and time constants.

        W~ = W_signed with wEI_correction added to the E<-I entry.
        L = (diag(1/gains) - W~)^-1 is the steady-state response matrix; it is None
        when |det| < 1e-10 or the matrix cannot be inverted.
        J = (diag(gains) W~ - I) / tau (row-wise) is the Jacobian of the rate dynamics.
        """
        W_eff = W_signed.copy()
        W_eff[0, 3] += wEI_correction
        jacobian = (gains[:, None] * W_eff - np.eye(4)) / tau[:, None]
        resp_inv = np.diag(1 / gains) - W_eff
        if abs(np.linalg.det(resp_inv)) < 1e-10:
            return None, jacobian
        try:
            return np.linalg.inv(resp_inv), jacobian
        except np.linalg.LinAlgError:
            return None, jacobian

    @staticmethod
    def _eigen_summary(J):
        """Return (max Re(lambda), max Im(lambda), oscillation metric) of J's eigenvalues.

        The oscillation metric is Im(l)^2 / (Im(l)^2 + Re(l)^2) for the leading
        eigenvalue l (largest real part).  Propagates LinAlgError/ValueError from
        scipy.linalg.eig (e.g. for non-finite input).
        """
        eigvals = eig(J)[0]
        real_parts = np.real(eigvals)
        lead = np.argmax(real_parts)
        max_re = real_parts[lead]
        im_lead_sq = np.imag(eigvals[lead])**2
        return max_re, np.max(np.imag(eigvals)), im_lead_sq / (im_lead_sq + max_re**2)

##################

    def plot_heatmaps_with_arrows_and_scatter(
            self,
            path_name,
            case_name,
            alfa,
            save_fig=False,
            iso_stab=None,
            iso_gain=None,
            iso_oscIm=None,
            iso_oscdist=None,
            max_rEP=10,
            step_arrows=100,
            nr_points=1000
    ) -> Tuple[Figure, Axes]:
        """Plot gain/stability/oscillation heatmaps with modulation arrows, plus delta scatters.

        Runs ``calculate_linear(max_rEP, alfa)`` and draws a 4x2 figure: normalized
        gain, normalized stability, delta stability vs delta gain, a parameter
        summary, normalized max Im(lambda), the oscillation metric, delta oscillation
        vs delta gain, and delta oscillation vs delta stability.  Arrows point along
        the (rP, rE) shift produced by the interneuron input.  They are drawn only
        where the unmodulated circuit is stable (max Re(lambda) < -0.05).

        Args:
            path_name: Sub-folder, under a folder named after the interneuron, used when saving.
            case_name: Used in the figure title and as the SVG file name.
            alfa: Passed through to ``calculate_linear``.
            save_fig: If True, save to ``<interneuron_name>/<path_name>/<case_name>.svg``.
            iso_stab, iso_gain, iso_oscIm, iso_oscdist: Optional contour levels for the
                stability, gain, Im(lambda) and oscillation-metric heatmaps.
            max_rEP: Upper bound of the rE / rP grid.
            step_arrows: Draw one arrow every ``step_arrows`` grid points per axis.
            nr_points: Number of randomly sampled grid points per scatter plot
                (capped at the grid size).

        Returns:
            The figure and the 4x2 array of axes created by ``plt.subplots``.
        """
        # Calculate dynamics
        (gain_num, gain_num_mod, maxEVs_num, maxEVs_num_mod, maxImEVs_num,
         oscMetric_num, oscMetric_num_mod, modI_rE_num, modI_rP_num,
         _modI_rS_num, _modI_rI_num) = self.calculate_linear(max_rEP, alfa)

        mask_threshold = -0.05
        maxEVs_num_set0 = np.minimum(maxEVs_num, 0)

        # 1 where the unmodulated circuit is clearly stable, 0 elsewhere, NaN where undefined.
        mask01 = np.where(maxEVs_num < mask_threshold, 1.0, 0.0)
        mask01[np.isnan(maxEVs_num)] = np.nan

        # Stability heatmap: max Re(lambda) clipped at 0, scaled so the most stable point is 1.
        norm_maxEVs_num_set0 = mask01 * maxEVs_num_set0
        nonzero_EVs = norm_maxEVs_num_set0[norm_maxEVs_num_set0 != 0]
        min_val_EV = np.nanmin(nonzero_EVs) if nonzero_EVs.size else 0
        if min_val_EV != 0:
            norm_maxEVs_num_set0 = norm_maxEVs_num_set0 / min_val_EV

        # Gain heatmap: min-max normalized over the stable region.
        stable_gain = gain_num[mask01 != 0]
        if np.all(np.isnan(stable_gain)):  # also True when there is no stable point at all
            norm_gain_num_masked = np.full_like(gain_num, np.nan)
        else:
            min_val_gain = np.nanmin(stable_gain)
            max_val_gain = np.nanmax(stable_gain)
            norm_gain_num_masked = mask01 * ((gain_num - min_val_gain) / (max_val_gain - min_val_gain))

        # Im(lambda) heatmap normalized by its largest defined value; undefined cells stay NaN.
        defined_Im = maxImEVs_num[~np.isnan(maxImEVs_num)]
        max_val_maxImEVs = defined_Im.max() if defined_Im.size else 0
        if max_val_maxImEVs != 0:
            norm_maxImEVs_num = maxImEVs_num / max_val_maxImEVs
        else:
            norm_maxImEVs_num = maxImEVs_num * 0.0

        # Create coordinate matrices
        rE_mat, rP_mat = np.meshgrid(self.rE_vec, self.rP_vec, indexing='ij')

        # Unit arrows along the modulation-induced (rP, rE) shift inside the stable region;
        # NaN (not drawn) where the shift is zero or the point is masked out.
        arrow_rP = mask01 * modI_rP_num
        arrow_rE = mask01 * modI_rE_num
        magnitude = np.hypot(arrow_rP, arrow_rE)
        magnitude[magnitude == 0] = np.nan
        delta_rP = arrow_rP / magnitude
        delta_rE = arrow_rE / magnitude

        # 1 where stable, NaN elsewhere, so unstable points drop out of the deltas.
        stable_nan = np.where(maxEVs_num < mask_threshold, 1.0, np.nan)
        stable_nan_mod = np.where(maxEVs_num_mod < mask_threshold, 1.0, np.nan)

        # Positive delta stability: the leading eigenvalue moved further left after modulation.
        diff_maxEVs_num_set0_2_mod = (stable_nan * maxEVs_num_set0
                                      - stable_nan_mod * np.minimum(maxEVs_num_mod, 0))
        diff_gain_num_set0_2_mod = stable_nan_mod * gain_num_mod - stable_nan * gain_num
        # Each state is masked by its own stability, as for the gain.
        diff_oscMetric_num_set0_2_mod = stable_nan_mod * oscMetric_num_mod - stable_nan * oscMetric_num

        # Random subsets of grid points for the scatter plots.
        n_cells = diff_maxEVs_num_set0_2_mod.size
        n_sample = min(nr_points, n_cells)  # choice(..., replace=False) cannot draw more than n_cells
        subset_i, subset_j = np.unravel_index(
            np.random.choice(n_cells, n_sample, replace=False), diff_maxEVs_num_set0_2_mod.shape)
        subset_i_osc, subset_j_osc = np.unravel_index(
            np.random.choice(n_cells, n_sample, replace=False), diff_oscMetric_num_set0_2_mod.shape)

        extent = (self.rP_vec[0], self.rP_vec[-1], self.rE_vec[-1], self.rE_vec[0])
        every = slice(None, None, step_arrows)

        def add_arrows(ax):
            """Overlay the subsampled modulation arrows on `ax`."""
            ax.quiver(rP_mat[every, every], rE_mat[every, every],
                      delta_rP[every, every], delta_rE[every, every],
                      angles='xy', color='black', alpha=0.7)

        def scatter_valid(ax, x_vals, y_vals, rows, cols):
            """Scatter the sampled points of (x_vals, y_vals) where both are defined, in one call."""
            xs = x_vals[rows, cols]
            ys = y_vals[rows, cols]
            keep = ~(np.isnan(xs) | np.isnan(ys))
            ax.scatter(xs[keep], ys[keep], s=1, c='k', alpha=0.5)

    #===============================================================================================

        # Create subplots
        fig, axes = plt.subplots(4, 2, figsize=(12,20))
        ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8 = axes.flatten()
        fig.suptitle(f'{case_name} case')

    ##### Plot 1: Gain with arrows
        im1 = ax1.imshow(norm_gain_num_masked, cmap=self.cmap, aspect='auto', extent=extent)
        add_arrows(ax1)
        if iso_gain is not None:
            ax1.contour(rP_mat, rE_mat, norm_gain_num_masked, levels=iso_gain, linewidths=1.5)
        ax1.set_title('Gain (normalized)')
        ax1.set_xlabel('rP (Hz)')
        ax1.set_ylabel('rE (Hz)')

    ##### Plot 2: Stability with arrows
        im2 = ax2.imshow(norm_maxEVs_num_set0, cmap=self.cmap, aspect='auto', extent=extent)
        add_arrows(ax2)
        if iso_stab is not None:
            ax2.contour(rP_mat, rE_mat, norm_maxEVs_num_set0, levels=iso_stab, linewidths=1.5)
        ax2.set_title('Stability (normalized)')
        ax2.set_xlabel('rP (Hz)')
        ax2.set_ylabel('rE (Hz)')

    ##### Plot 3: Scatter plot for deltas of gain and stability
        scatter_valid(ax3, diff_maxEVs_num_set0_2_mod, diff_gain_num_set0_2_mod, subset_i, subset_j)
        ax3.axhline(y=0, color='r', linestyle='--', alpha=0.7)
        ax3.axvline(x=0, color='r', linestyle='--', alpha=0.7)
        ax3.set_title('mod ' + self.interneuron_name)
        ax3.set_xlabel(r'$\Delta$ stability')
        ax3.set_ylabel(r'$\Delta$ gain')
        ax3.set_xlim([-0.5, 0.5])
        ax3.set_ylim([-4, 4])

    ##### Plot 4: Info about parameters
        matrix_str = (
            rf'wES factor $\alpha$: {alfa}' + '\n\n'
            + f'Random Points: {nr_points}\n\n'
            + f'SST init. rate: {self.rS0}\n'
            + self.interneuron_name.upper() + f' init. rate: {self.rI0}\n\n'
            + f'Stimulation E: {self.I_stim_E}\nStimulation PV: {self.I_stim_P}\nModulation '
            + self.interneuron_name.upper() + f': {self.I_mod_I}\n\n'
            + 'Weight Matrix:\n'
            + '\n'.join('  '.join(f'{val:.2f}' for val in row) for row in self.W)
        )
        ax4.axis('off')  # text-only panel
        ax4.text(0.5, 0.5, matrix_str, fontsize=12, ha='center', va='center', family='monospace')

    ##### Plot 5: The biggest Im{lambda} oscillatory plot
        im5 = ax5.imshow(norm_maxImEVs_num, cmap=self.cmap, aspect='auto', extent=extent)
        add_arrows(ax5)
        if iso_oscIm is not None:
            ax5.contour(rP_mat, rE_mat, norm_maxImEVs_num, levels=iso_oscIm, linewidths=1.5)
        ax5.set_title(r'Max $Im(\lambda)$ (normalized)')
        ax5.set_xlabel('rP (Hz)')
        ax5.set_ylabel('rE (Hz)')

    ##### Plot 6: Oscillation metric of the leading eigenvalue
        im6 = ax6.imshow(oscMetric_num, cmap=self.cmap, aspect='auto', extent=extent, vmin=0, vmax=1)
        add_arrows(ax6)
        if iso_oscdist is not None:
            ax6.contour(rP_mat, rE_mat, oscMetric_num, levels=iso_oscdist, linewidths=1.5)
        ax6.set_title(r'Oscillations: $\frac{|Im(\lambda_{lead})|^2}{|Im(\lambda_{lead})|^2 + |Re(\lambda_{lead})|^2}$')
        ax6.set_ylabel('rE (Hz)')
        ax6.set_xlabel('rP (Hz)')

    ##### Plot 7: Scatter plot delta gain and oscillations
        scatter_valid(ax7, diff_oscMetric_num_set0_2_mod, diff_gain_num_set0_2_mod, subset_i_osc, subset_j_osc)
        ax7.axhline(y=0, color='r', linestyle='--', alpha=0.7)
        ax7.axvline(x=0, color='r', linestyle='--', alpha=0.7)
        ax7.set_title(self.interneuron_name.upper() + ' modulation')
        ax7.set_xlabel(r'$\Delta$ oscillations')
        ax7.set_ylabel(r'$\Delta$ gain')
        ax7.set_xlim([-0.6, 0.6])
        ax7.set_ylim([-4, 4])

    ##### Plot 8: Scatter plot delta stability and oscillations
        scatter_valid(ax8, diff_oscMetric_num_set0_2_mod, diff_maxEVs_num_set0_2_mod, subset_i_osc, subset_j_osc)
        ax8.axhline(y=0, color='r', linestyle='--', alpha=0.7)
        ax8.axvline(x=0, color='r', linestyle='--', alpha=0.7)
        ax8.set_title(self.interneuron_name.upper() + ' modulation')
        ax8.set_xlabel(r'$\Delta$ oscillations')
        ax8.set_ylabel(r'$\Delta$ stability')
        ax8.set_xlim([-0.6, 0.6])
        ax8.set_ylim([-0.5, 0.5])

        # Add colorbars
        plt.colorbar(im1, ax=ax1, shrink=0.8)
        plt.colorbar(im2, ax=ax2, shrink=0.8)
        plt.colorbar(im5, ax=ax5, shrink=0.8)
        plt.colorbar(im6, ax=ax6, shrink=0.8)

        plt.tight_layout()
        if save_fig:
            # pathlib joins the folders portably on every OS.
            folder = Path(self.interneuron_name) / path_name
            folder.mkdir(parents=True, exist_ok=True)
            fig.savefig(folder / f'{case_name}.svg')
        plt.show()

        return fig, axes
    
########################

    def calculate_dynamics(
            self,
            rE0,
            rP0,
            alfa,
            dt,
            end_sim,
            OnlyDeltas=False,
            t_mod=50
    ):
        """
        Simulate the rate dynamics and the rate changes caused by the modulatory input.

        E and PV start at the passed rates, SST and NDNF/VIP at ``self.rS0`` / ``self.rI0``.
        The baseline input is chosen so that these initial rates are a fixed point of
        ``r = mult_f * (W r + I) ** power``. At time ``t_mod`` the input to the fourth
        population is increased by ``self.I_mod_I`` exactly once. The weight W[0, 2] (wES)
        is recomputed every step as ``-wES_0 * exp(-rI / alfa)``. Integration is forward
        Euler with step ``dt``.

        Parameters
        ----------
        rE0, rP0 : float
            Initial E and PV rates.
        alfa : float
            Decay constant of the rI-dependence of wES.
        dt : float
            Euler time step.
        end_sim : float
            Simulation end time.
        OnlyDeltas : bool, optional
            If True, return only the final rate changes.
        t_mod : float, optional
            Time at which the modulatory input is switched on (default 50).

        Returns
        -------
        ndarray of shape (4,)
            If ``OnlyDeltas``: final minus initial rates (E, P, S, I).
        tuple
            Otherwise ``(rE_save, rP_save, rS_save, rI_save, wES_save, wIS_save, delta_r)``.
            Each ``*_save`` list holds the initial value followed by one value per step.
        """
        tau = np.array([[self.tauE], [self.tauP], [self.tauS], [self.tauI]])

        # Signed weights: presynaptic P, S, I columns are inhibitory. Force float so the
        # rate-dependent wES assignment below is never truncated when W holds integers.
        W = np.asarray(self.W, dtype=float) * np.array([1.0, -1.0, -1.0, -1.0])
        W[0, 2] = -self.wES_0 * np.exp(-self.rI0 / alfa)

        r_init = np.array([[rE0], [rP0], [self.rS0], [self.rI0]], dtype=float)
        r = r_init.copy()
        I_mod = np.array([[0], [0], [0], [self.I_mod_I]], dtype=float)

        # Baseline input that makes r_init a fixed point of the rate equation
        I = (r / self.mult_f) ** (1 / self.power) - np.dot(W, r)

        if not OnlyDeltas:
            rE_save = [r[0, 0]]
            rP_save = [r[1, 0]]
            rS_save = [r[2, 0]]
            rI_save = [r[3, 0]]
            wES_save = [-W[0, 2]]
            wIS_save = [-W[3, 2]]

        modulated = False
        for t in np.arange(dt, end_sim + dt, dt):
            # Switch the modulation on exactly once, at the step nearest t_mod. A tolerance
            # test such as np.isclose(t, t_mod, atol=dt) matches up to three consecutive
            # steps and would add I_mod several times.
            if not modulated and t >= t_mod - 0.5 * dt:
                I = I + I_mod
                modulated = True

            q = np.dot(W, r) + I
            r += dt * ((self.mult_f * (q ** self.power)) - r) / tau
            W[0, 2] = -self.wES_0 * np.exp(-r[3, 0] / alfa)

            if not OnlyDeltas:
                rE_save.append(r[0, 0])
                rP_save.append(r[1, 0])
                rS_save.append(r[2, 0])
                rI_save.append(r[3, 0])
                wES_save.append(-W[0, 2])
                wIS_save.append(-W[3, 2])

        delta_r = (r - r_init).reshape(4)

        if OnlyDeltas:
            return delta_r
        return (rE_save, rP_save, rS_save, rI_save, wES_save, wIS_save, delta_r)
    
#####################
        
    def plot_rate_weight_dynamics(
        self,
        rE0,
        rP0,
        alfa,
        path_name=None,
        case_name=None,
        dt=0.01,
        end_sim=350,
        save_fig=False
    ) -> Tuple[Figure, Axes]:
        """
        Plot the simulated rates (top) and the wES / wIS weights (bottom) for one system.

        The traces come from ``self.calculate_dynamics(rE0, rP0, alfa, dt, end_sim)``.
        When ``save_fig`` is True the figure is written to
        ``<interneuron_name>/<path_name>/<case_name>.svg``; the folder is created only then.

        Returns
        -------
        (fig, ax) : the figure and its array of two axes.
        """
        (rE_save, rP_save, rS_save, rI_save,
         wES_save, wIS_save, _delta_r) = self.calculate_dynamics(rE0, rP0, alfa, dt, end_sim)

        # Time axis derived from the number of stored samples (initial state + one per
        # step), so it always matches the traces. An independent np.arange over
        # [0, end_sim] can differ in length by one because of floating-point rounding.
        time = dt * np.arange(len(rE_save))

        fig, ax = plt.subplots(2, 1, figsize=(10, 8))
        ax[0].plot(time, rE_save, 'r', label='rE')
        ax[0].plot(time, rP_save, 'b', label='rP')
        ax[0].plot(time, rS_save, 'g', label='rS')
        ax[0].plot(time, rI_save, 'm', label='rI')
        ax[0].set_xlabel("Time (s)")
        ax[0].set_ylabel("rX (Hz)")
        ax[0].legend().set_title(rf"$\alpha$ = {alfa}")
        ax[0].set_xlim([0, end_sim])

        ax[1].plot(time, wES_save, 'r', label=r'$w_{ES}$')
        ax[1].plot(time, wIS_save, 'm', label=r'$w_{IS}$')
        ax[1].set_xlabel("Time (s)")
        ax[1].set_ylabel("Weight")
        ax[1].legend().set_title(rf"$\alpha$ = {alfa}")
        ax[1].set_xlim([0, end_sim])

        plt.tight_layout()
        if save_fig:
            # Touch the filesystem only when saving; build the path with the OS separator.
            folder = Path(self.interneuron_name) / str(path_name)
            folder.mkdir(parents=True, exist_ok=True)
            plt.savefig(folder / f'{case_name}.svg')
        plt.show()

        return fig, ax
    
#############################

    def plot_drE_rP0_fix_alpha(
            self,
            alfa,
            rE0,
            rP0_vec=np.arange(0, 201, 1),
            dt=0.01,
            end_sim=350,
            progress=True
    ) -> Tuple[Figure, Axes]:
        """
        Plot the modulation-induced change of the E rate against the initial PV rate.

        ``alfa`` and ``rE0`` are fixed. For each ``rP0`` in ``rP0_vec`` the dynamics are
        simulated with ``self.calculate_dynamics(..., OnlyDeltas=True)`` and Delta rE is
        scattered against rP0.

        Returns
        -------
        (fig, ax) : the figure and its axes.
        """
        n_points = len(rP0_vec)
        delta_rE_save = np.zeros_like(rP0_vec, dtype=float)

        for i, rP0 in enumerate(rP0_vec):
            delta_rE_save[i] = self.calculate_dynamics(rE0, rP0, alfa, dt, end_sim, OnlyDeltas=True)[0]

            # Progress report overwritten in place. Reporting every iteration guarantees
            # the display actually advances and reaches 100%.
            if progress:
                print(f'Finished {int(100 * (i + 1) / n_points)}%', end="\r")

        fig, ax = plt.subplots(figsize=(10, 4))
        # Labelled so the legend (used to display the parameter title) has an entry
        ax.scatter(rP0_vec, delta_rE_save, c="#610000", marker=".", s=10, label=r'$\Delta r_E$')
        ax.set_xlabel(r'$r_{P_0} \, [Hz]$')
        ax.set_ylabel(r'$\Delta r_E \, [Hz]$')
        ax.legend(loc='upper left', bbox_to_anchor=(1.07, 0.5)).set_title(fr'$r_E0 = {rE0}$, $\alpha = {alfa}$' + '\n' + fr'$r_S0 = {self.rS0}$, $r_I0 = {self.rI0}$')

        return fig, ax
    
#############################     

    def plot_rate_change_alfa(
            self,
            rE0,
            rP0,
            alfa_vec,
            dt=0.01,
            end_sim=350,
            LinearCompare=False,
            OnlyExcitatory=True,
            margins=0.1,
            progress=True
    ) -> Tuple[Figure, Axes]:
        """
        Plot the modulation-induced rate changes against the alfa factor.

        For every ``alfa`` in ``alfa_vec`` the nonlinear dynamics are simulated with
        ``self.calculate_dynamics(..., OnlyDeltas=True)``. The resulting rate changes are
        scattered against alfa: only Delta rE, unless ``OnlyExcitatory`` is False.

        With ``LinearCompare=True`` the linear-response prediction ``L_XI * self.I_mod_I``
        is also shown. It is evaluated from closed-form response-matrix expressions at the
        initial rates (rE0, rP0, self.rS0, self.rI0) and drawn on a twin y-axis. The twin
        axis shares its y-limits (padded by ``margins``) with the simulated changes.

        Returns
        -------
        (fig, ax) : the figure and its primary axes.

        Raises
        ------
        ValueError
            If ``LinearCompare`` is True and ``self.interneuron_name`` is neither
            'ndnf' nor 'vip'.

        Notes
        -----
        For 'ndnf' the linear comparison assigns ``self.wES = self.wES_0 * exp(-rI0 / alfa)``
        for each alfa. After the call, ``self.wES`` therefore holds the value for the last alfa.
        """
        # Fail before the (slow) simulations rather than with a NameError mid-loop
        if LinearCompare and self.interneuron_name not in ('ndnf', 'vip'):
            raise ValueError(
                f'Invalid interneuron name {self.interneuron_name!r}. '
                'Available options: "vip" or "ndnf".'
            )

        n_alfa = len(alfa_vec)
        delta_r_save = np.zeros((4, n_alfa))
        for i, alfa in enumerate(alfa_vec):
            delta_r_save[:, i] = self.calculate_dynamics(rE0, rP0, alfa, dt, end_sim, OnlyDeltas=True)

            # Progress report overwritten in place; reaches 100% on the last alfa
            if progress:
                print(f'Finished {int(100 * (i + 1) / n_alfa)}%', end="\r")

        if LinearCompare:
            L_save = np.zeros((4, n_alfa))
            # Linearization point: the initial (pre-modulation) rates
            rS = self.rS0
            rI = self.rI0
            rE = rE0
            rP = rP0
            for i, alfa in enumerate(alfa_vec):
                # Correction to W from the rate dependence of wES on the NDNF rate
                if self.interneuron_name == 'ndnf':
                    self.wES = self.wES_0 * np.exp(-rI/alfa)
                    pESN = (self.wES_0/alfa) * np.exp(-rI/alfa)  # pESN=dW_13/drN, where W_13=-wES
                    # Cap so W*_14 + correction keeps the sign of W*_14 = -wEI
                    wEI_correction = min(pESN * rS, self.wEI)
                else:  # 'vip': no rate-dependent weight correction
                    wEI_correction = 0

                # Calculate inputs
                xE = (rE/self.mult_f)**(1/self.power) - (self.wEE*rE - self.wEP*rP - self.wES*rS - self.wEI*rI)
                xP = (rP/self.mult_f)**(1/self.power) - (self.wPE*rE - self.wPP*rP - self.wPS*rS - self.wPI*rI)
                xS = (rS/self.mult_f)**(1/self.power) - (self.wSE*rE - self.wSP*rP - self.wSS*rS - self.wSI*rI)
                xI = (rI/self.mult_f)**(1/self.power) - (self.wIE*rE - self.wIP*rP - self.wIS*rS - self.wII*rI)

                # Calculate derivatives (gains)
                bE = self.power * self.mult_f * (self.wEE*rE - self.wEP*rP - self.wES*rS - self.wEI*rI + xE)**(self.power-1)
                bP = self.power * self.mult_f * (self.wPE*rE - self.wPP*rP - self.wPS*rS - self.wPI*rI + xP)**(self.power-1)
                bS = self.power * self.mult_f * (self.wSE*rE - self.wSP*rP - self.wSS*rS - self.wSI*rI + xS)**(self.power-1)
                bI = self.power * self.mult_f * (self.wIE*rE - self.wIP*rP - self.wIS*rS - self.wII*rI + xI)**(self.power-1)

                # Response matrix terms calculation. The determinant expression is wrapped
                # in parentheses so it can legally span several physical lines.
                det_num = (
                    -self.wEE*self.wII*self.wPP*self.wSS + self.wEE*self.wII*self.wPS*self.wSP + self.wEE*self.wIP*self.wPI*self.wSS - self.wEE*self.wIP*self.wPS*self.wSI - self.wEE*self.wIS*self.wPI*self.wSP + self.wEE*self.wIS*self.wPP*self.wSI + self.wEI*self.wIE*self.wPP*self.wSS - self.wEI*self.wIE*self.wPS*self.wSP - self.wEI*self.wIP*self.wPE*self.wSS + self.wEI*self.wIP*self.wPS*self.wSE + self.wEI*self.wIS*self.wPE*self.wSP - self.wEI*self.wIS*self.wPP*self.wSE - self.wEP*self.wIE*self.wPI*self.wSS + self.wEP*self.wIE*self.wPS*self.wSI + self.wEP*self.wII*self.wPE*self.wSS - self.wEP*self.wII*self.wPS*self.wSE - self.wEP*self.wIS*self.wPE*self.wSI + self.wEP*self.wIS*self.wPI*self.wSE + self.wES*self.wIE*self.wPI*self.wSP - self.wES*self.wIE*self.wPP*self.wSI - self.wES*self.wII*self.wPE*self.wSP + self.wES*self.wII*self.wPP*self.wSE + self.wES*self.wIP*self.wPE*self.wSI - self.wES*self.wIP*self.wPI*self.wSE - self.wIE*self.wPP*self.wSS*wEI_correction + self.wIE*self.wPS*self.wSP*wEI_correction + self.wIP*self.wPE*self.wSS*wEI_correction - self.wIP*self.wPS*self.wSE*wEI_correction - self.wIS*self.wPE*self.wSP*wEI_correction + self.wIS*self.wPP*self.wSE*wEI_correction - self.wEE*self.wII*self.wPP/bS + self.wEE*self.wIP*self.wPI/bS + self.wEI*self.wIE*self.wPP/bS - self.wEI*self.wIP*self.wPE/bS - self.wEP*self.wIE*self.wPI/bS + self.wEP*self.wII*self.wPE/bS - self.wIE*self.wPP*wEI_correction/bS + self.wIP*self.wPE*wEI_correction/bS - self.wEE*self.wII*self.wSS/bP + self.wEE*self.wIS*self.wSI/bP + self.wEI*self.wIE*self.wSS/bP - self.wEI*self.wIS*self.wSE/bP - self.wES*self.wIE*self.wSI/bP + self.wES*self.wII*self.wSE/bP - self.wIE*self.wSS*wEI_correction/bP + self.wIS*self.wSE*wEI_correction/bP - self.wEE*self.wII/(bP*bS) + self.wEI*self.wIE/(bP*bS) - self.wIE*wEI_correction/(bP*bS) - self.wEE*self.wPP*self.wSS/bI + self.wEE*self.wPS*self.wSP/bI + self.wEP*self.wPE*self.wSS/bI - self.wEP*self.wPS*self.wSE/bI - self.wES*self.wPE*self.wSP/bI +
                    self.wES*self.wPP*self.wSE/bI - self.wEE*self.wPP/(bI*bS) + self.wEP*self.wPE/(bI*bS) - self.wEE*self.wSS/(bI*bP) + self.wES*self.wSE/(bI*bP) - self.wEE/(bI*bP*bS) + self.wII*self.wPP*self.wSS/bE - self.wII*self.wPS*self.wSP/bE - self.wIP*self.wPI*self.wSS/bE + self.wIP*self.wPS*self.wSI/bE + self.wIS*self.wPI*self.wSP/bE - self.wIS*self.wPP*self.wSI/bE + self.wII*self.wPP/(bE*bS) - self.wIP*self.wPI/(bE*bS) + self.wII*self.wSS/(bE*bP) - self.wIS*self.wSI/(bE*bP) + self.wII/(bE*bP*bS) + self.wPP*self.wSS/(bE*bI) - self.wPS*self.wSP/(bE*bI) + self.wPP/(bE*bI*bS) + self.wSS/(bE*bI*bP) + 1/(bE*bI*bP*bS)
                )

                L_save[0, i] = (-self.wEI*self.wPP*self.wSS + self.wEI*self.wPS*self.wSP + self.wEP*self.wPI*self.wSS - self.wEP*self.wPS*self.wSI - self.wES*self.wPI*self.wSP + self.wES*self.wPP*self.wSI + self.wPP*self.wSS*wEI_correction - self.wPS*self.wSP*wEI_correction - self.wEI*self.wPP/bS + self.wEP*self.wPI/bS + self.wPP*wEI_correction/bS - self.wEI*self.wSS/bP + self.wES*self.wSI/bP + self.wSS*wEI_correction/bP - self.wEI/(bP*bS) + wEI_correction/(bP*bS)) / det_num

                if not OnlyExcitatory:
                    L_save[1, i] = (self.wEE*self.wPI*self.wSS - self.wEE*self.wPS*self.wSI - self.wEI*self.wPE*self.wSS + self.wEI*self.wPS*self.wSE + self.wES*self.wPE*self.wSI - self.wES*self.wPI*self.wSE + self.wPE*self.wSS*wEI_correction - self.wPS*self.wSE*wEI_correction + self.wEE*self.wPI/bS - self.wEI*self.wPE/bS + self.wPE*wEI_correction/bS - self.wPI*self.wSS/bE + self.wPS*self.wSI/bE - self.wPI/(bE*bS)) / det_num
                    L_save[2, i] = (-self.wEE*self.wPI*self.wSP + self.wEE*self.wPP*self.wSI + self.wEI*self.wPE*self.wSP - self.wEI*self.wPP*self.wSE - self.wEP*self.wPE*self.wSI + self.wEP*self.wPI*self.wSE - self.wPE*self.wSP*wEI_correction + self.wPP*self.wSE*wEI_correction + self.wEE*self.wSI/bP - self.wEI*self.wSE/bP + self.wSE*wEI_correction/bP + self.wPI*self.wSP/bE - self.wPP*self.wSI/bE - self.wSI/(bE*bP)) / det_num
                    L_save[3, i] = (-self.wEE*self.wPP*self.wSS + self.wEE*self.wPS*self.wSP + self.wEP*self.wPE*self.wSS - self.wEP*self.wPS*self.wSE - self.wES*self.wPE*self.wSP + self.wES*self.wPP*self.wSE - self.wEE*self.wPP/bS + self.wEP*self.wPE/bS - self.wEE*self.wSS/bP + self.wES*self.wSE/bP - self.wEE/(bP*bS) + self.wPP*self.wSS/bE - self.wPS*self.wSP/bE + self.wPP/(bE*bS) + self.wSS/(bE*bP) + 1/(bE*bP*bS)) / det_num

            # Linear prediction of the rate change: response term times the input step.
            # Rows left uncomputed (OnlyExcitatory) stay 0.
            linear_delta_r_save = L_save * self.I_mod_I

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.scatter(alfa_vec, delta_r_save[0, :], c="#610000", marker="v", s=10, label=r'$\Delta r_E$')
        if not OnlyExcitatory:
            ax.scatter(alfa_vec, delta_r_save[1, :], c='#000066', marker="v", label=r'$\Delta r_P$')
            ax.scatter(alfa_vec, delta_r_save[2, :], c='#009900', marker="v", label=r'$\Delta r_S$')
            ax.scatter(alfa_vec, delta_r_save[3, :], c="#782747", marker="v", label=r'$\Delta r_I$')
        ax.set_xlabel(r'$\alpha$')
        ax.set_ylabel(r'$\Delta r_X \, [Hz]$')

        legend_title = fr'$r_E0 = {rE0}$, $r_P0 = {rP0}$' + '\n' + fr'$r_S0 = {self.rS0}$, $r_I0 = {self.rI0}$'
        if LinearCompare:
            axL = ax.twinx()
            axL.scatter(alfa_vec, linear_delta_r_save[0, :], c="#ff0000", marker=".", label=r'$\Delta r_E\,linear$')
            if not OnlyExcitatory:
                axL.scatter(alfa_vec, linear_delta_r_save[1, :], c='#1a1aff', marker=".", label=r'$\Delta r_P\,linear$')
                axL.scatter(alfa_vec, linear_delta_r_save[2, :], c="#00CD00", marker=".", label=r'$\Delta r_S\,linear$')
                axL.scatter(alfa_vec, linear_delta_r_save[3, :], c="#d6457f", marker=".", label=r'$\Delta r_I\,linear$')
            axL.set_ylabel(r'$\Delta r_X$ for linearized system')
            # Give both y-axes the same (padded) limits so the two series are comparable
            (a1, a2), (l1, l2) = ax.get_ylim(), axL.get_ylim()
            axL.set_ylim(min(a1, l1) - margins, max(a2, l2) + margins)
            ax.set_ylim(min(a1, l1) - margins, max(a2, l2) + margins)
            lines2, labels2 = axL.get_legend_handles_labels()
            lines1, labels1 = ax.get_legend_handles_labels()
            ax.legend(lines1 + lines2, labels1 + labels2, loc="center left", bbox_to_anchor=(1.07, 0.5)).set_title(legend_title)
        else:
            ax.legend(loc="center left", bbox_to_anchor=(1.07, 0.5)).set_title(legend_title)

        return fig, ax
    
#######################################

    def plot_rates_for_alphas(self,
                              rE0,
                              rP0,
                              alfa_vec,
                              dt=0.01,
                              end_sim=350
    ) -> Tuple[Figure, Axes]:
        """
        Plot the simulated rate dynamics for several alfa values, two panels per row.

        Each panel shows rE, rP, rS and rI from ``self.calculate_dynamics`` for one alfa.
        Panels left over when ``alfa_vec`` has odd length are hidden.

        Returns
        -------
        (fig, ax) : the figure and the axes array returned by ``plt.subplots(n_rows, 2)``.
        """
        # At least one row so an empty alfa_vec does not make plt.subplots fail
        n_rows = max(1, int(np.ceil(len(alfa_vec) / 2)))
        fig, ax = plt.subplots(n_rows, 2, figsize=(12, n_rows * 3))
        panels = ax.flatten()

        for panel, alfa in zip(panels, alfa_vec):
            rE_save, rP_save, rS_save, rI_save, *_ = self.calculate_dynamics(rE0, rP0, alfa, dt, end_sim)
            # Time axis from the number of stored samples, so lengths always agree
            time = dt * np.arange(len(rE_save))
            panel.plot(time, rE_save, 'r', label=r'$r_E$')
            panel.plot(time, rP_save, 'b', label=r'$r_P$')
            panel.plot(time, rS_save, 'g', label=r'$r_S$')
            panel.plot(time, rI_save, 'm', label=r'$r_I$')
            panel.set_xlabel("Time (s)")
            panel.set_ylabel(r"$r_X$ (Hz)")
            panel.legend().set_title(rf"$\alpha$ = {alfa}")
            panel.set_xlim([0, end_sim])

        # Hide unused panels (odd number of alfas)
        for panel in panels[len(alfa_vec):]:
            panel.axis('off')

        plt.tight_layout()

        return fig, ax
        


## Run script

In [2]:
# Contour (isoline) levels drawn on the gain, stability and oscillation heatmaps.
# The last three share the same levels but are independent list objects.
iso_gain_num = [0, 0.2, 0.3, 0.4]
iso_stab_num, iso_oscIm_num, iso_oscdist_num = ([0, 0.5, 0.7, 1] for _ in range(3))

# Alternative signed connectivity kept for reference (order E, P, S, I):
# W = np.array([[0.8, -1, -0.9, -0.1],
#               [1, -0.6, -0.3, -0],
#               [0.2, -0, -0, -0.4],
#               [0.2, -0.3, -0.4, -0]])

In [189]:
# Connectivity magnitudes: rows are the receiving population, columns the sending one
# (order E, P, S, I). The three regimes differ only in where the fourth population
# (NDNF/VIP) projects: onto E only (inh), onto PV only (disinh), or onto both.
W_both = np.array([
    [0.8, 0.5, 0.3, 0.3],
    [1, 0.6, 0.8, 0.3],
    [0.2, 0, 0, 0],
    [0, 0, 0.3, 0]
])

W_inh = W_both.copy()
W_inh[1, 3] = 0      # drop the I -> P projection

W_disinh = W_both.copy()
W_disinh[0, 3] = 0   # drop the I -> E projection

alfa_num = 1
path_name_num = 'test_weight_depend'

In [190]:
# NDNF-modulated circuit in the inhibitory connectivity regime
Circuit = ExtendedCircuit(
    'ndnf',
    W_inh,
    step_rX=0.05,
)

In [None]:
# Heatmaps with flow arrows and delta-scatter plots for the base inhibitory regime
heatmap = Circuit.plot_heatmaps_with_arrows_and_scatter(
    path_name_num,
    f'alpha{alfa_num}_base_inhibitory',
    alfa=alfa_num,
    save_fig=False,
    step_arrows=10,
    nr_points=1000,
)

In [None]:
# Single-trajectory rate/weight dynamics. The same alpha value feeds both the simulation
# and case_name, so a saved figure is named after the alpha actually simulated.
alfa_dyn_num = 3.1
dynamics = Circuit.plot_rate_weight_dynamics(
    rE0=0.1,
    rP0=5,
    alfa=alfa_dyn_num,
    path_name=path_name_num,
    case_name=f'alpha{alfa_dyn_num}_Imod',
    dt=0.1,
    end_sim=600,
    save_fig=False,
)

In [None]:
# Alpha values to scan. plot_rate_change_alfa takes them as `alfa_vec`; it has no
# `inv_alfa_vec` keyword, so passing that name raises TypeError.
inv_alfa_vec = np.arange(0.1, 10.1, 1)
rate_alfa_fig, rate_alfa_ax = Circuit.plot_rate_change_alfa(
    rE0=0.1,
    rP0=5,
    alfa_vec=inv_alfa_vec,
    end_sim=450,
    LinearCompare=False,
    margins=0.1,
)

In [154]:
inv_alfa_vec = np.arange(0.1, 7, 1)

# The connectivity regime is fixed for the whole sweep, so classify it once. An
# unrecognised W must fail clearly instead of leaving `regime` unbound (or stale).
if np.array_equal(Circuit.W, W_inh):
    regime = 'inh'
elif np.array_equal(Circuit.W, W_disinh):
    regime = 'disinh'
elif np.array_equal(Circuit.W, W_both):
    regime = 'both'
else:
    raise ValueError('Circuit.W matches none of W_inh, W_disinh, W_both')

# Portable output path; create it so savefig does not fail on a fresh checkout
out_dir = Path('ndnf') / 'dr_alpha_graph_rErPchange'
out_dir.mkdir(parents=True, exist_ok=True)

for rE0_num in np.arange(3.1, 4, 1):
    for rP0_num in np.arange(20, 100, 2):
        print(f'rE={rE0_num} and rP={rP0_num}', end='\r')

        fig, ax = Circuit.plot_rate_change_alfa(
            rE0=rE0_num,
            rP0=rP0_num,
            alfa_vec=inv_alfa_vec,
            end_sim=450,
            LinearCompare=False,
            margins=0.1,
            progress=False
        )

        file_path = out_dir / f'{regime}_rE_{rE0_num}_rP_{rP0_num}.png'
        fig.savefig(file_path, dpi=300, bbox_inches="tight")
        plt.close(fig)

print('\n'+'Finished!')

rE=3.1 and rP=98
Finished!


In [None]:
# Delta rE as a function of the initial PV rate, for fixed alpha and rE0
drE_rP0_fig, drE_rP0_ax = Circuit.plot_drE_rP0_fix_alpha(
    alfa=3.1,
    rE0=0.1,
    rP0_vec=np.arange(0, 802, 2),
    end_sim=450,
)
# Dashed zero reference line
drE_rP0_ax.axhline(y=0, linestyle='--', linewidth=0.5, color='red')
drE_rP0_fig.show()

# Maximizing $L_{EN}$ ($\Delta r_E$) by minimizing $w_{EI}, \, w_{PP}, \, r_P$

In [193]:
# Two variants of the inhibitory-regime weights, each differing in a single entry:
# wEI (row E, column I) lowered from 0.3 to 0.1, or wPP (row P, column P) from 0.6 to 0.1.
_W_drE_reference = np.array([
    [0.8, 0.5, 0.3, 0.3],
    [1, 0.6, 0.8, 0],
    [0.2, 0, 0, 0],
    [0, 0, 0.3, 0]
])

W_drE_max_wEI_change = _W_drE_reference.copy()
W_drE_max_wEI_change[0, 3] = 0.1

W_drE_max_wPP_change = _W_drE_reference.copy()
W_drE_max_wPP_change[1, 1] = 0.1

del _W_drE_reference

In [204]:
# NDNF circuit using the reduced-wPP connectivity variant
CircuitMaxRate_inh = ExtendedCircuit(
    'ndnf',
    W_drE_max_wPP_change,
    step_rX=0.05,
)

In [None]:
# Alpha values to scan; passed as `alfa_vec` (plot_rate_change_alfa has no
# `inv_alfa_vec` keyword, so that name raises TypeError).
inv_alfa_vec_max_rate = np.arange(0.1, 1.3, 0.1)
CircuitMaxRate_inh.plot_rate_change_alfa(
    rE0=0.01,
    rP0=200,
    alfa_vec=inv_alfa_vec_max_rate,
    end_sim=450,
    LinearCompare=False,
    margins=0.1,
)

In [None]:
# Rate/weight dynamics for the reduced-wPP circuit. The same alpha value feeds both the
# simulation and case_name, so a saved figure is named after the alpha actually simulated.
alfa_dyn_num = 3.1
CircuitMaxRate_inh.plot_rate_weight_dynamics(
    rE0=0.01,
    rP0=200,
    alfa=alfa_dyn_num,
    path_name=path_name_num,
    case_name=f'alpha{alfa_dyn_num}_Imod',
    dt=0.1,
    end_sim=600,
    save_fig=False,
)

In [2]:
from pathlib import Path
import yaml
import numpy as np

# ---------- Load YAML ----------
cfg_path = Path("configs/reference_dynamics.yaml")
# Read the file once, explicitly as UTF-8. The parsed config and the copy stored in the
# artifact are then guaranteed to be the same text, regardless of the platform's
# default locale encoding.
cfg_text = cfg_path.read_text(encoding="utf-8")
cfg = yaml.safe_load(cfg_text)

c = cfg["circuit"]
g = cfg["grid"]
o = cfg["output"]

# ---------- Parse inputs ----------
W = np.array(c["W"], dtype=float)
tau = np.array(c.get("tau", [10, 10, 10, 10]), dtype=float)

interneuron_name = c["interneuron_name"]
rS0 = float(c.get("rS0", 2.0))
rI0 = float(c.get("rI0", 2.0))
step_rX = float(c.get("step_rX", 0.01))

I_stim_E = float(c.get("I_stim_E", 0.3))
I_stim_P = float(c.get("I_stim_P", 0.3))
I_mod_I = float(c.get("I_mod_I", 0.3))

power = float(c.get("power", 2.0))
mult_f = float(c.get("mult_f", 0.25))

max_rEP = float(g["max_rEP"])
alfa = float(g["alfa"])

out_npz = Path(o["npz_path"])
out_npz.parent.mkdir(parents=True, exist_ok=True)

print("Loaded config:", cfg["experiment"]["name"])
print("W shape:", W.shape, "tau:", tau, "step_rX:", step_rX)

# ---------- Create circuit ----------
circuit = ExtendedCircuit(
    interneuron_name=interneuron_name,
    W=W,
    tau=tau,
    rS0=rS0,
    rI0=rI0,
    step_rX=step_rX,
    I_stim_E=I_stim_E,
    I_stim_P=I_stim_P,
    I_mod_I=I_mod_I,
    power=power,
    mult_f=mult_f
)

# ---------- Run reference grid ----------
(
    f_gain_num,
    f_gain_num_mod,
    f_maxEVs_num,
    f_maxEVs_num_mod,
    f_maxImEVs_num,
    f_oscMetric_num,
    f_oscMetric_num_mod,
    f_modI_rE_num,
    f_modI_rP_num,
    f_modI_rS_num,
    f_modI_rI_num,
) = circuit.calculate_linear(max_rEP=max_rEP, alfa=alfa)

# ---------- Save reference artifact ----------
# NOTE: cfg_yaml and interneuron_name are stored as object arrays, so loading this
# file requires allow_pickle=True.
np.savez(
    out_npz,
    # inputs
    cfg_yaml=np.array(cfg_text, dtype=object),
    interneuron_name=np.array(interneuron_name, dtype=object),
    W=W,
    tau=tau,
    rS0=np.array(rS0),
    rI0=np.array(rI0),
    step_rX=np.array(step_rX),
    I_stim_E=np.array(I_stim_E),
    I_stim_P=np.array(I_stim_P),
    I_mod_I=np.array(I_mod_I),
    power=np.array(power),
    mult_f=np.array(mult_f),
    max_rEP=np.array(max_rEP),
    alfa=np.array(alfa),

    # grid axes (as produced by your function)
    rE_vec=np.array(circuit.rE_vec, dtype=float),
    rP_vec=np.array(circuit.rP_vec, dtype=float),

    # outputs (the “golden” arrays)
    f_gain_num=f_gain_num,
    f_gain_num_mod=f_gain_num_mod,
    f_maxEVs_num=f_maxEVs_num,
    f_maxEVs_num_mod=f_maxEVs_num_mod,
    f_maxImEVs_num=f_maxImEVs_num,
    f_oscMetric_num=f_oscMetric_num,
    f_oscMetric_num_mod=f_oscMetric_num_mod,
    f_modI_rE_num=f_modI_rE_num,
    f_modI_rP_num=f_modI_rP_num,
    f_modI_rS_num=f_modI_rS_num,
    f_modI_rI_num=f_modI_rI_num,
)

print("Saved reference artifact to:", out_npz)
print("Shapes:",
      "gain", f_gain_num.shape,
      "maxEV", f_maxEVs_num.shape,
      "rE_vec", circuit.rE_vec.shape,
      "rP_vec", circuit.rP_vec.shape)
print("Sanity means:",
      "gain mean =", float(np.nanmean(f_gain_num)),
      "maxEV mean =", float(np.nanmean(f_maxEVs_num)))


Loaded config: reference_dynamics
W shape: (4, 4) tau: [10. 10. 10. 10.] step_rX: 0.5
Saved reference artifact to: reference\reference_dynamics.npz
Shapes: gain (4, 4) maxEV (4, 4) rE_vec (4,) rP_vec (4,)
Sanity means: gain mean = 0.6042729262569879 maxEV mean = -0.04373754904799867


In [4]:
# Inspect the saved reference artifact. np.load on an .npz returns an NpzFile that keeps
# the file open; the context manager closes it once the checks are done.
# allow_pickle=True is needed for the object-dtype entries (cfg_yaml, interneuron_name);
# only enable it for trusted, locally produced files.
with np.load("reference/reference_dynamics.npz", allow_pickle=True) as z:
    print("Keys:", z.files[:10], "... total:", len(z.files))
    print("gain shape:", z["f_gain_num"].shape, "mean:", float(np.nanmean(z["f_gain_num"])))
    print("maxEV shape:", z["f_maxEVs_num"].shape, "mean:", float(np.nanmean(z["f_maxEVs_num"])))
    print("rE_vec:", z["rE_vec"][:5], "...", z["rE_vec"][-5:])
    print("rP_vec:", z["rP_vec"][:5], "...", z["rP_vec"][-5:])


Keys: ['cfg_yaml', 'interneuron_name', 'W', 'tau', 'rS0', 'rI0', 'step_rX', 'I_stim_E', 'I_stim_P', 'I_mod_I'] ... total: 27
gain shape: (4, 4) mean: 0.6042729262569879
maxEV shape: (4, 4) mean: -0.04373754904799867
rE_vec: [2.  1.5 1.  0.5] ... [2.  1.5 1.  0.5]
rP_vec: [0.5 1.  1.5 2. ] ... [0.5 1.  1.5 2. ]
