### This script simulates the simplified four-population model

- Units: <br />
Pyr, PV, SST and VIP

- Connections:<br />
Pyr targets: Pyr, PV, SST, VIP<br />
PV targets: Pyr, PV, VIP<br />
SST targets: Pyr, PV, VIP<br />
VIP targets: SST<br />

- Equations:<br />

Pyr
$$\tau_E\frac{dr_E}{dt}=-r_E+\Phi_E[(W_{EE} - W_{EE}^\star) r_E - (W_{ES} - W_{ES}^\star) r_S + J_E]$$

SST control
$$\tau_S\frac{dr_S}{dt}=-r_S+\Phi_S[(W_{SE} - W_{SE}^\star)r_E + W_{SS}^\star r_S + J_S]$$

SST DART
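$$\tau_S\frac{dr_S}{dt}=-r_S+\Phi_S[((1 - x)W_{SE} - W_{SE}^\star)r_E + W_{SS}^\star r_S + J_S]$$

where $x$ is the fractional reduction of the Pyr→SST weight $W_{SE}$ under DART ($x = 0$ recovers the control equation).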


In [92]:

import numpy as np
import matplotlib.pyplot as plt

# Figure configurations
# Global matplotlib defaults shared by every figure in this notebook.

# Tick geometry for both axes: major ticks have length 10, minor ticks
# length 5, and every tick has width 2.
_tick_rc = {
    f"{axis}tick.{kind}.{prop}": value
    for axis in ("x", "y")
    for kind, length in (("major", 10), ("minor", 5))
    for prop, value in (("size", length), ("width", 2))
}

plt.rcParams.update({
    # Hide the top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Thicker axes frame and plotted lines.
    "axes.linewidth": 2,
    "lines.linewidth": 2,
    **_tick_rc,
    # Font sizes: tick labels, axis labels, legend, axes titles, and the
    # global default font size.
    "xtick.labelsize": 20,
    "ytick.labelsize": 20,
    "axes.labelsize": 20,
    "legend.fontsize": 20,
    "axes.titlesize": 20,
    "font.size": 20,
    # Font family used for all text.
    "font.family": "Arial",
})

In [93]:
# Variables initialization

# Adjustable
max_time = 200  # total simulated time in ms (200 ms -> 2000 steps at dt = 0.1 ms)
tau_e, tau_s = 1, 1  # time constants
gain_e, gain_s = 1.289, 1.481  # gains of the rectified quadratic transfer function r = gain * z**2
re_start = rs_start = 0.0  # initial values of the firing rates


# Dynamics
dt = 0.1  # time step in ms
# Round before converting: max_time / dt can land just below an integer in
# floating point, and a bare int() would then silently drop a time step.
T = int(round(max_time / dt))  # total number of time steps
stim_start = int(0.25 * T)  # first time step of the stimulus
stim_end = int(0.75 * T)  # last time step of the stimulus (inclusive)

# Units
neuron_types = ['Pyr', 'SST']  # neuron types of the simplified model
# Empirical mean responses, indexed [state, contrast]
rEdatacontrol = np.array([[0.1012, 0.1032, 0.1254], [0.1824, 0.1939, 0.2282]])
rEdataDART = np.array([[0.1190, 0.1304, 0.1593], [0.2021, 0.2235, 0.2665]])
rSdatacontrol = np.array([[0.0539, 0.0728, 0.1014], [0.1285, 0.1688, 0.2091]])
rSdataDART = np.array([[0.0246, 0.0463, 0.0883], [0.1116, 0.1717, 0.2423]])

# Conditions
states = ['Stationary', 'Running']  # states of the mice
contrasts = ['25%', '50%', '100%']  # contrast levels

contrast_values = [0.25, 0.5, 1.0]  # contrast values

# Connectivity
# 11 parameters from fitting the model to the data:
#   params[0:5]  -> WEE, WES, WSE, WSVE, WSS in the stationary state
#   params[5:10] -> the same five weights in the running state
#   params[10]   -> DART effect x (fractional reduction of WSE)

# all weights vary
params = [1.11255061,  0.1546616,   2.18953047, -2.88933692,  1.14327749,
          0.9575452,   0.07522963,  0.33622663, -0.96945827, 0.72326915, 0.5]

# only VIP-related weights vary
# params = [1.00722926,  0.23913302,  2.37976463, -3.71018897,  1.49218142,
#           1.00722926,  0.23913302,  2.37976463, -7.69489824, -1.95185175, 0.5]

# all parameters are fixed
# params = [0.95948705 ,  0.22874792,   4.94801569, -10.  ,         0.58251944,
#         0.95948705,   0.22874792,   4.94801569, -10. ,          0.58251944, 0.5]

# Scale factor (1 - x) applied to WSE: control (x = 0) and DART (x = params[10])
groups = [1.0, 1 - params[10]]

# Connection weights for each state, taken from consecutive 5-parameter slices
weight_names = ['WEE', 'WES', 'WSE', 'WSVE', 'WSS']
weights = {
    state: dict(zip(weight_names, params[5 * i:5 * i + 5]))
    for i, state in enumerate(states)
}


In [None]:
# Simulation
#
# Forward-Euler integration of the two-population (Pyr, SST) rate model for
# every (state, contrast, group) combination. Before the stimulus the rates
# are clamped to their initial values and the inputs are (0, -1). Inside the
# stimulus window [stim_start, stim_end] the inputs are the constant currents
# that make the control data a fixed point of the control (group == 1.0)
# dynamics. They are obtained by inverting r = gain * z**2 at the measured
# control rates.


def _plot_rate_panel(panel, neuron_index, s, c):
    """Draw the simulated traces of one cell type into panel `panel` of a 1x2 grid."""
    plt.subplot(1, 2, panel)
    for g, group in enumerate(groups):
        plt.plot(np.arange(T), fr_simulation[neuron_index, s, c, g, :],
                 label='Control' if group == 1.0 else f'x = {1 - group:.1f}',
                 color='black' if group == 1.0 else 'blue')
    plt.xlim((0, T))
    plt.ylim((0, 0.4))
    plt.xticks([])
    plt.yticks([])
    # Dashed markers at stimulus onset and offset.
    plt.axvline(stim_start, ls='--')
    plt.axvline(stim_end, ls='--')
    # plt.title(neuron_types[neuron_index])
    # plt.legend(title='Group', loc='upper left', fontsize=15)


# Firing rates indexed [neuron type, state, contrast, group, time step].
# np.longdouble is the portable name for extended precision; np.float128 only
# exists on some platforms (e.g. not on Windows or Apple-silicon builds).
fr_simulation = np.zeros((len(neuron_types), len(states),
                          len(contrasts), len(groups), T), dtype=np.longdouble)

for s, state in enumerate(states):
    w = weights[state]
    for c, contrast in enumerate(contrasts):
        r_e_data = rEdatacontrol[s, c]
        r_s_data = rSdatacontrol[s, c]
        # Stimulus-window inputs (loop-invariant, so computed once per condition).
        I_ex_stim = np.sqrt(r_e_data / gain_e) - \
            w['WEE'] * r_e_data + \
            w['WES'] * r_s_data
        I_sx_stim = np.sqrt(r_s_data / gain_s) - \
            (w['WSE'] - w['WSVE']) * r_e_data - \
            w['WSS'] * r_s_data

        for g, group in enumerate(groups):
            # Start every run from the initial rates so that no state carries
            # over between runs, even when stim_start == 0.
            re, rs = re_start, rs_start
            for t in range(T):
                if stim_start <= t <= stim_end:
                    I_ex_temp, I_sx_temp = I_ex_stim, I_sx_stim
                else:
                    I_ex_temp, I_sx_temp = 0, -1
                if t < stim_start:
                    re, rs = re_start, rs_start

                z_e = w['WEE'] * re - w['WES'] * rs + I_ex_temp
                # `group` scales WSE by (1 - x) to model the DART effect.
                z_s = (w['WSE'] * group - w['WSVE']) * re + \
                    w['WSS'] * rs + I_sx_temp

                # Rectify before the quadratic transfer function.
                z_e = z_e * (z_e > 0)
                z_s = z_s * (z_s > 0)

                re = re + (-re + gain_e * z_e ** 2) / tau_e * dt
                rs = rs + (-rs + gain_s * z_s ** 2) / tau_s * dt

                fr_simulation[:, s, c, g, t] = [re, rs]

        # Plot the simulated Pyr (left) and SST (right) traces for this condition.
        plt.figure(figsize=(10, 5))
        _plot_rate_panel(1, 0, s, c)
        _plot_rate_panel(2, 1, s, c)
        # plt.suptitle(f'{state} State, {contrast} Contrast')
        plt.tight_layout()
        # plt.savefig(f'{state}_{contrast}.pdf', format='pdf')
        plt.show()


In [None]:
# Calculate mean firing rates in the steady state.
# Average the last `steady_window` time steps of the stimulus window, ending
# at t = stim_end (the last step at which the stimulus input is applied).
# The window spans steady_window * dt ms, which is 50 ms at dt = 0.1 ms.
# Result is indexed [neuron type, state, contrast, group].
steady_window = 500
fr_steady = np.mean(
    fr_simulation[..., stim_end - steady_window + 1:stim_end + 1], axis=-1)
fr_steady

In [None]:
# Plot firing rates in steady state for different conditions with empirical data

def plot_data(subplot_position, fr_avg, fr_se,
              fr_steady_index, neuron_type, fr_model=None):
    """Overlay empirical and simulated steady-state responses in one panel.

    Draws into panel `subplot_position` of a 2x2 grid.

    - Empirical data: Control/DART means with error bars. Colors and legend
      labels come from the module-level `colors` and `labels`.
    - Model: simulated steady-state rates for every entry of `groups`, in
      orange.

    Parameters
    ----------
    subplot_position : int
        1-based panel index in the 2x2 grid.
    fr_avg, fr_se : sequence of pairs
        One (Control, DART) pair per contrast: means and error-bar sizes.
    fr_steady_index : tuple
        (neuron_index, state_index) selecting the simulated rates.
    neuron_type : str
        Panel title.
    fr_model : array, optional
        Simulated steady-state rates indexed [neuron, state, contrast, group].
        Defaults to the module-level `fr_steady`.
    """
    if fr_model is None:
        fr_model = fr_steady

    plt.subplot(2, 2, subplot_position)
    for i, (label, color) in enumerate(zip(labels, colors)):
        means = [pair[i] for pair in fr_avg]
        errors = [pair[i] for pair in fr_se]
        plt.errorbar(contrast_values, means, yerr=errors, label=label,
                     color=color, capsize=5, zorder=1)

    # Rows are contrasts, columns are groups (control, DART).
    model_rates = fr_model[fr_steady_index]
    for g, group in enumerate(groups):
        plt.plot(contrast_values, model_rates[:, g],
                 label='Control' if group == 1.0 else f'x = {1 - group:.1f}',
                 color='orange', zorder=2)

    plt.title(neuron_type)
    plt.xticks(contrast_values, contrasts)
    plt.ylim((0, 0.3))
    plt.yticks(np.arange(0, 0.4, 0.1))
    plt.xlabel('Contrast')
    plt.ylabel('dF/F')

plt.figure(figsize=(10, 10))
colors = ['black', (0, 0, 1)]
labels = ['Control', 'DART']

# Empirical means, one entry per panel in the order
# stationary (Pyr, SST), running (Pyr, SST).
# Each entry is a (contrast x 2) array of (Control, DART) pairs.
# It is built from the same data arrays that drive the simulation, so the
# plotted data cannot drift from the fitted data.
_data_control = [rEdatacontrol, rSdatacontrol]  # indexed by neuron type
_data_dart = [rEdataDART, rSdataDART]
exp_avg = [np.column_stack((_data_control[n][s], _data_dart[n][s]))
           for s in range(len(states))
           for n in range(len(neuron_types))]
# Error-bar sizes, same layout as exp_avg.
exp_se = [[[0.005528, 0.007429], [0.006080, 0.007788], [0.007634, 0.009315]],
          [[0.008280, 0.004273], [0.009692, 0.006342], [0.012741, 0.010546]],
          [[0.009524, 0.011529], [0.009672, 0.011742], [0.011056, 0.013709]],
          [[0.013620, 0.013669], [0.014691, 0.015772], [0.019223, 0.021288]]]

# Loop for 2 states and 2 neuron types
for s, state in enumerate(states):
    for n, neuron_type in enumerate(neuron_types):
        plot_data(2*s+n+1, exp_avg[2*s+n], exp_se[2*s+n], (n, s), neuron_type)

plt.tight_layout()
# uncomment the following line to save the figure
# plt.savefig('Empirical and simulated responses-all changing.pdf')
plt.show()