In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from brian2 import *
import collections

In [None]:
# Define parameters of the model and the default values for the mean-field parameters.
# Units follow the Brian2 conversion used for the network: C in pF, g_L in nS/mV,
# voltages in mV, T_w/T_s in ms, a and Q_e in nS, b/Eta/Delta in pA.
DynMemory = collections.namedtuple(
    typename='DynMemory',
    field_names='C ,g_L ,E_L ,v_th ,T_w ,a  ,b  ,E_e ,Q_e ,T_s, Eta ,Delta ,v_reset, v_peak')

default_adj_params = DynMemory(
    C=200, g_L=10, E_L=-62, v_th=-55, T_w=20, a=4, b=20, E_e=0, Q_e=4, T_s=6, Eta=10, Delta=.5, v_reset=-70., v_peak=-10.)


# Adjusted mean-field equations for a network of eQIF neurons.
# The correction term of the current is discussed in the Supplementary Material of the paper.
def adjusted_dmMF(t, y, p: DynMemory, I_func):
    """Right-hand side of the adjusted mean-field ODEs (for ``solve_ivp``).

    Parameters
    ----------
    t : float
        Time (ms).
    y : sequence of 4 floats
        State ``(r, v, w, s)``: firing rate, mean membrane potential, mean
        adaptation current and synaptic conductance.
    p : DynMemory
        Model parameters; unpacked positionally in ``DynMemory`` field order.
    I_func : callable
        ``I_func(t)`` returns the external current at time ``t``.

    Returns
    -------
    numpy.ndarray
        Time derivatives ``[dr, dv, dw, ds]``.
    """
    r, v, w, s = y
    C, g_L, E_L, v_th, T_w, a, b, E_e, Q_e, T_s, Eta, Delta, v_reset, v_peak = p
    I0 = I_func(t)

    # Without the rate term, C*dv/dt = g_L*v**2 - alpha*v + beta.
    alpha = g_L*(v_th + E_L) + s
    beta = g_L*v_th*E_L + s*E_e + w + I0 + Eta

    mu = 4*beta/g_L - (alpha/g_L)**2

    # Finite reset/peak correction: replace I0 by the current that makes the
    # reset->peak travel time of an infinite-peak QIF match the finite one.
    # Only defined in the oscillatory regime (mu > 0); at mu == 0 (or if gamma
    # degenerates to 0) the formula divides by zero, so fall back to I0.
    I_adj = I0
    if mu > 0:
        sqrt_mu = np.sqrt(mu)
        gamma = np.arctan((2*v_peak - alpha/g_L)/sqrt_mu) - np.arctan((2*v_reset - alpha/g_L)/sqrt_mu)
        if gamma != 0:
            I_adj = g_L*mu*np.pi**2 / (4*gamma**2) + alpha**2 / (4*g_L) - g_L*E_L*v_th - s*E_e - w - Eta

    dr = r*(g_L*(2*v - v_th - E_L) - s) / C + Delta*g_L/(np.pi*C**2)
    dv = (g_L * (E_L - v) * (v_th - v) + w + I_adj + s*(E_e - v) + Eta - (np.pi*C*r)**2/g_L) / C
    dw = (a * (v - E_L) - w) / T_w + b*r
    ds = -s/T_s + Q_e*r

    return np.array([dr, dv, dw, ds])

In [None]:
# Time grid shared by both simulations: integration step and total duration (ms)
dt, duration_mf = 0.01, 1500
# The same two values as Brian2 quantities; the Brian2 clock uses the same step
dtime, duration = dt * ms, duration_mf * ms
defaultclock.dt = dtime

# The mean-field model runs with the default parameter set
adj_params = default_adj_params

# Number of neurons in the Brian2 network
N = 5000

# Brian2 parameters with units, taken from the same parameter set as the
# mean-field model. The synaptic increment is split across the N presynaptic
# neurons (Q_e / N per spike), apparently so that the summed population drive
# matches the Q_e * r term of the mean-field equation for s.
(C, g_L, E_L, V_T, tau_w, a, b,
 V_reset, V_peak, Ee, Qe, Tsyn) = (
    adj_params.C * pF,
    adj_params.g_L * nS/mV,
    adj_params.E_L * mV,
    adj_params.v_th * mV,
    adj_params.T_w * ms,
    adj_params.a * nS,
    adj_params.b * pA,
    adj_params.v_reset * mV,
    adj_params.v_peak * mV,
    adj_params.E_e * mV,
    adj_params.Q_e/N * nS,
    adj_params.T_s * ms,
)

In [None]:
# Define the time-varying external current for the mean-field simulation.
AmpStep = 60   # amplitude (pA) of the large steps above/below baseline
BaseI = 90     # baseline external current (pA)
Pert = 10      # amplitude (pA) of the small perturbations
def input_current(t):
    """External current (pA) driving the mean-field model at time ``t`` (ms).

    Piecewise constant: ``BaseI`` outside the stimulation windows, and the
    value noted next to each window inside its open interval (the window
    endpoints themselves return the baseline). The pA values in the comments
    assume the default constants above.
    """
    # NOTE(review): the first step starts at 105 ms here but at 100 ms in the
    # Brian2 TimedArray built below -- confirm the 5 ms offset is intended.
    if 105 < t < 250:
        return BaseI + AmpStep   # 105-250 ms: 150 pA
    if 650 < t < 720:
        return BaseI - Pert      # 650-720 ms: 80 pA
    if 780 < t < 850:
        return BaseI + Pert      # 780-850 ms: 100 pA
    if 1000 < t < 1150:
        return BaseI - AmpStep   # 1000-1150 ms: 30 pA
    if 1250 < t < 1320:
        return BaseI + Pert      # 1250-1320 ms: 100 pA
    if 1350 < t < 1420:
        return BaseI - Pert      # 1350-1420 ms: 80 pA
    return BaseI

# Define the time-varying external current for the Brian2 simulation using a
# TimedArray sampled on the simulation grid (one value per dtime step).
# Each entry is (start ms, stop ms, current pA); samples in [start, stop) get
# the current, everything else stays at BaseI. Mirrors input_current except
# that the first window starts at 100 ms here (105 ms in input_current).
_brian_current_schedule = (
    (100, 250, BaseI + AmpStep),
    (650, 720, BaseI - Pert),
    (780, 850, BaseI + Pert),
    (1000, 1150, BaseI - AmpStep),
    (1250, 1320, BaseI + Pert),
    (1350, 1420, BaseI - Pert),
)


def _step_index(t_ms):
    """Index of the simulation step at ``t_ms`` (ms) on the ``dtime`` grid.

    Rounded rather than truncated so float error (e.g. 149999.999...) does
    not drop a step.
    """
    return round(float(t_ms * ms / dtime))


time_steps = _step_index(duration_mf)
# Float dtype so non-integer currents are not silently truncated on assignment
current_array = np.full(time_steps, BaseI, dtype=float)
for t_on, t_off, amp in _brian_current_schedule:
    current_array[_step_index(t_on):_step_index(t_off)] = amp
I_t = TimedArray(current_array * pA, dt=dtime)

In [None]:
# Simulate the adjusted mean-field model from random initial conditions.
# SEED = None draws fresh entropy (a different run every time); set an int to
# make the run reproducible. `rnd` is also reused by the network cell for the
# neurons' initial V and w.
SEED = None
rnd = np.random.default_rng(SEED)
# Initial (r, v, w, s), each drawn uniformly from its range and rounded to 3 decimals
y0_ranges = ((.001, .003), (-70., -65.), (1., 5.), (.001, .005))
y0 = [np.round(rnd.uniform(low, high), 3) for low, high in y0_ranges]
adj_sim = solve_ivp(adjusted_dmMF, (0, duration_mf), y0,
                    args=(adj_params, input_current), max_step=dt)

In [None]:
# Define the spiking neural network in Brian2: N eQIF neurons with an
# adaptation current w, driven by a shared excitatory synaptic conductance.
start_scope()

# Single-element group holding the population-wide synaptic conductance. It
# decays with time constant Tsyn and is read by every neuron of G through the
# linked variable `Gesyn` (see G.Gesyn below).
Gsyn = NeuronGroup(1, '''
dGesyn/dt = -Gesyn/Tsyn : siemens
''', method='rk4')
Gsyn.Gesyn = 0*siemens

# eQIF model equations with synaptic input and external current. Same form as
# dv/dw in adjusted_dmMF, but driven by the raw step current I_t(t) (no
# finite-peak correction) and a per-neuron constant current n in place of Eta.
adj_eqs = '''
dV/dt = (g_L * (E_L - V) * (V_T - V) + w + I_ext + n - Gesyn*(V-Ee)) / C : volt
dw/dt = (a * (V - E_L) - w) / tau_w : amp
I_ext = I_t(t) : amp
n : amp
Gesyn : siemens (linked)
'''

# Set neuron group: spike when V exceeds V_peak, then reset V and increment w by b
G = NeuronGroup(N, adj_eqs, threshold='V > V_peak', reset='V = V_reset; w += b', method='rk4')

# Quenched heterogeneity: deterministic quantiles of a Lorentzian (Cauchy)
# distribution centred at Eta with half-width Delta, shuffled so that the
# constant currents are not ordered by neuron index (the raster plots only a
# subset of indices).
# NOTE(review): `rng` is unseeded, so the current-to-neuron assignment
# differs between runs.
e = adj_params.Eta
d = adj_params.Delta
x = np.linspace(0+1/N,1-1/N,N)
rng = np.random.default_rng()
adj_etas = e + d*np.tan(np.pi*(x-0.5))
rng.shuffle(adj_etas)

# Random initial conditions (rounded to 3 decimals), drawn from the `rnd`
# generator created in the mean-field simulation cell
Vinit = np.round(rnd.uniform(-70., -60., size=N), 3)
Winit = np.round(rnd.uniform(1., 5., size=N), 3)
G.n = adj_etas * pA
G.V =  Vinit * mV
G.w = Winit * pA
G.Gesyn = linked_var(Gsyn, 'Gesyn')

# Connect every neuron of G to the single Gsyn element: each spike increments
# the shared conductance by Qe (= Q_e / N)
S = Synapses(G, Gsyn, on_pre='Gesyn_post += Qe')
S.connect()

# Monitor variables
adj_M_spike = SpikeMonitor(G)
adj_M_FR = PopulationRateMonitor(G)
# Record the membrane potential at a coarser sampling period than the
# simulation step, to reduce memory usage.
# NOTE(review): record=True still stores V for all N neurons; with N=5000,
# sample_rate=0.1 ms and a 1500 ms run that is ~7.5e7 samples (~600 MB if
# stored as float64) -- consider recording a subset of neurons.
sample_rate = 0.1*ms
adj_M_voltage = StateMonitor(G, 'V', record=True, dt=sample_rate)

# Run the network for the full duration; all monitors above record throughout
run(duration)

In [None]:
# Safety check: raster of neurons 0-1000 and membrane potential of a few
# neurons, to verify activity and single-neuron dynamics before the analyses.
# raster_activity rows: spike times (ms), neuron indices (reused by later cells).
raster_activity = np.array([adj_M_spike.t/ms, adj_M_spike.i])
index_mask = (raster_activity[1] <= 1000)

fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(raster_activity[0][index_mask], raster_activity[1][index_mask], ',k')
ax.set(xlabel='t (ms)', ylabel='Neuron #', title='Raster (neurons 0-1000)')
plt.show()

fig, ax = plt.subplots(figsize=(12, 6))
for i in range(5):
    # Offset each trace by 20 mV for visibility
    ax.plot(adj_M_voltage.t/ms, adj_M_voltage.V[i]/mV + i*20, label=f'Neuron {i}')
ax.set_xlim(50, 1100)
ax.set(xlabel='t (ms)', ylabel='V (mV, +20 mV per neuron)',
       title='Membrane potential of neurons 0-4')
ax.legend(loc='upper right')
plt.show()

In [None]:
# Compute the ISI coefficient of variation (CV) of each neuron within the
# stimulation period (50-1100 ms), using consecutive windows of time_interval ms
# to capture time-varying behaviour.
time_interval = 75
time_windows = np.arange(50, 1100 + time_interval, time_interval)

# Spike times (ms) and neuron indices from the raster built above
spike_times = raster_activity[0]
spike_ids = raster_activity[1].astype(int)
# Number of per-neuron bins needed to group spikes by index. Deliberately NOT
# stored in N: other cells treat N as the network size.
n_bins = int(spike_ids.max()) + 1 if spike_ids.size else 0


def _window_isi_cvs(times_ms, ids, n_neurons, t_start, t_stop):
    """ISI CV of every neuron spiking in [t_start, t_stop); NaN when it has <= 2 spikes there."""
    in_window = (times_ms >= t_start) & (times_ms < t_stop)
    per_neuron = [[] for _ in range(n_neurons)]
    for t_spk, idx in zip(times_ms[in_window], ids[in_window]):
        per_neuron[idx].append(np.round(t_spk, 2))
    # Silent neurons are skipped; a CV needs at least two ISIs
    return [np.std(np.diff(ts)) / np.mean(np.diff(ts)) if len(ts) > 2 else np.nan
            for ts in per_neuron if len(ts) > 0]


def _nan_stat(reduce, values):
    """Apply a NaN-ignoring reduction; return NaN (no warning/error) if no non-NaN value exists."""
    arr = np.asarray(values, dtype=float)
    return reduce(arr) if np.any(~np.isnan(arr)) else np.nan


def _as_steps(per_window, samples_per_window):
    """Repeat each per-window value so it can be drawn as a step function on the sim grid."""
    return np.repeat(np.asarray(per_window, dtype=float), samples_per_window)


cv_measure = [_window_isi_cvs(spike_times, spike_ids, n_bins, t0, t1)
              for t0, t1 in zip(time_windows[:-1], time_windows[1:])]

# Samples per window on the simulation grid; round() avoids int() truncating e.g. 7499.999...
time_bin_len = round(float(time_interval / (dtime/ms)))
# Mean, min, max and std of the CV in each window, as step arrays for plotting
cv_means = _as_steps([_nan_stat(np.nanmean, cvs) for cvs in cv_measure], time_bin_len)
cv_mins = _as_steps([_nan_stat(np.nanmin, cvs) for cvs in cv_measure], time_bin_len)
cv_maxs = _as_steps([_nan_stat(np.nanmax, cvs) for cvs in cv_measure], time_bin_len)
cv_stds = _as_steps([_nan_stat(np.nanstd, cvs) for cvs in cv_measure], time_bin_len)

In [None]:
def compute_chi_sliding_window(state_monitor, time_window, f_sample, window_size=50*ms, step_size=10*ms):
    """Time-resolved synchrony measure chi(t) from recorded membrane potentials.

    In each sliding window:
        chi = sqrt( Var_t(V_pop) / mean_i Var_t(V_i) )
    where V_pop is the voltage averaged over the recorded neurons and
    Var_t(V_i) is the temporal variance of neuron i within the window.

    Parameters
    ----------
    state_monitor : StateMonitor
        Brian2 StateMonitor with voltage recordings ``V``.
    time_window : (float, float)
        Analysis interval [start, stop) in ms.
    f_sample : Quantity
        Sampling *period* of the monitor (e.g. ``sample_rate``); converts
        window and step durations into numbers of samples.
    window_size : Quantity
        Duration of each sliding window. Windows starting near the end of the
        interval are shorter (truncated at the last sample).
    step_size : Quantity
        Time between successive window starts.

    Returns
    -------
    times : ndarray
        Mean time (ms) of the samples in each window.
    chi_t : ndarray
        chi for each window (0.0 when the individual variances vanish).
    var_V : ndarray
        Variance of the population-mean voltage in each window (mV^2).
    mean_var_Vi : ndarray
        Mean of the individual voltage variances in each window (mV^2).
    """
    time_mask = (state_monitor.t/ms >= time_window[0]) & (state_monitor.t/ms < time_window[1])
    V_traces = state_monitor.V[:, time_mask] / mV
    effective_times = state_monitor.t[time_mask]

    # Durations -> sample counts. round() avoids float truncation (99.999... -> 99)
    # and max(1, ...) avoids a zero step, which would make range() raise.
    window_samples = max(1, round(float(window_size / f_sample)))
    step_samples = max(1, round(float(step_size / f_sample)))
    n_timepoints = V_traces.shape[1]

    chi_t = []
    times = []
    var_V = []
    mean_var_Vi = []

    for start_idx in range(0, n_timepoints, step_samples):
        end_idx = start_idx + window_samples
        V_window = V_traces[:, start_idx:end_idx]

        # Variance of the population-averaged trace vs. the average of the
        # individual variances, taken over the neurons actually recorded
        # (not a global network size).
        V_pop_window = np.mean(V_window, axis=0)
        sigma_V_squared = np.var(V_pop_window)
        sigma_Vi_squared = np.var(V_window, axis=1)
        mean_sigma_Vi_squared = np.mean(sigma_Vi_squared)

        if mean_sigma_Vi_squared > 0:
            chi = np.sqrt(sigma_V_squared / mean_sigma_Vi_squared)
        else:
            chi = 0.0

        time = effective_times[start_idx:end_idx].mean()
        chi_t.append(chi)
        times.append(time/ms)
        var_V.append(sigma_V_squared)
        mean_var_Vi.append(mean_sigma_Vi_squared)

    return np.array(times), np.array(chi_t), np.array(var_V), np.array(mean_var_Vi)

In [None]:
# Compute the time-resolved synchrony measure chi(t) from the membrane potentials
# over 50-1100 ms and plot it against the external current.
# NOTE: a smaller sampling period (sample_rate) gives more samples per window
# and therefore a finer-resolved chi(t).
times, chi_t, var_V, mean_var_Vi = compute_chi_sliding_window(
    adj_M_voltage, [50, 1100], f_sample=sample_rate, window_size=10*ms, step_size=2*ms)

# First simulation step at 50 ms (rounded so float error cannot shift it by one)
start_step = round(float(50*ms / dtime))

plt.figure(figsize=(12, 4))
ax = plt.gca()
ax.plot(adj_M_FR.t[start_step:]/ms, current_array[start_step:])
ax.set_xlabel('t (ms)')
ax.set_ylabel('$I_{ext}$ (pA)')
ax2 = ax.twinx()
# Distinct colour: the twin axes would otherwise reuse the first cycle colour
ax2.plot(times, chi_t, color='C1')
ax2.set_ylabel('χ(t)')
ax2.set_xlim(50, 1500)
plt.show()

In [None]:
# Final figure: five stacked panels sharing the 0-1500 ms time axis.
# (a)-(c) show the whole simulation; the chi(t) and CV traces in (d)-(e) only
# cover the analysis interval (50-1100 ms), which is shaded green in the raster.

# Global font sizes (persist for any later figure in this kernel)
plt.rcParams.update({
    'font.size': 20,          # Controls default text size
    'axes.titlesize': 20,     # Title font size
    'axes.labelsize': 20,     # X/Y label font size
    'xtick.labelsize': 18,    # X tick labels
    'ytick.labelsize': 18,    # Y tick labels
    'legend.fontsize': 14,    # Legend font size
})
fig, axs = plt.subplots(5, 1, height_ratios=[.5, 1, 1, .5, .5], figsize=(12, 12), sharex=True)
ax0, ax1, ax2, ax3, ax4 = axs


def _hide_top_right_spines(axis):
    """Remove the top and right frame lines of ``axis``."""
    for side in ('top', 'right'):
        axis.spines[side].set_visible(False)


# Plot time ranges (ms): full simulation and analysis interval
full_start, full_end = 0, 1500
zoom_start, zoom_end = 50, 1100

# (a) External current
ax0.plot(adj_M_FR.t/ms, current_array, color='purple')
ax0.set_xlim(full_start, full_end)
ax0.set_ylabel('$I_{ext}$ (pA)')
ax0.xaxis.set_visible(False)
_hide_top_right_spines(ax0)

# (b) Population firing rate: smoothed SNN rate vs. mean-field rate
# (the MF time unit is ms, so r * 1000 converts to Hz)
adj_M_FRsmt = adj_M_FR.smooth_rate(window='flat', width=1.01*ms)
ax1.plot(adj_M_FR.t/ms, adj_M_FRsmt/Hz, 'k', label='SNN')
ax1.plot(adj_sim.t, adj_sim.y[0]*1000, 'r', lw=2, label='MF')
ax1.set_ylabel('$FR$ (Hz)')
ax1.set_xlim(full_start, full_end)
ax1.set_ylim(0, 150)
ax1.legend(loc='upper right')
ax1.xaxis.set_visible(False)
_hide_top_right_spines(ax1)

# (c) Raster plot of neurons 0-1000; the green shaded area highlights the
# interval where the measures in (d) and (e) were computed
index_mask = (raster_activity[1] <= 1000)
ax2.plot(raster_activity[0][index_mask], raster_activity[1][index_mask], ',k')
ax2.fill([zoom_start, zoom_start, zoom_end, zoom_end], [0, 1000, 1000, 0], color='lightgreen', alpha=0.5)
ax2.set_xlim(full_start, full_end)
ax2.set_ylabel('Neuron #')
_hide_top_right_spines(ax2)

# (d) chi(t) synchrony measure (based on membrane potential)
ax3.plot(times, chi_t, 'k')
ax3.set_ylabel('χ(t)')
ax3.xaxis.set_visible(False)
_hide_top_right_spines(ax3)

# (e) Mean ISI coefficient of variation (CV) across the population, with the
# per-window min-max range as faint error bars
cv_time = np.linspace(zoom_start, zoom_end, len(cv_means))
yerr_lower = cv_means - cv_mins
yerr_upper = cv_maxs - cv_means
ax4.errorbar(cv_time, cv_means, yerr=[yerr_lower, yerr_upper], fmt='none', ecolor='orange', alpha=0.01)
ax4.plot(cv_time, cv_means, 'k')
ax4.set_ylabel('CV')
ax4.set_xlabel('t (ms)')
_hide_top_right_spines(ax4)

# Align y-labels across subplots and leave enough left margin for the letters
fig.align_ylabels(axs)
fig.subplots_adjust(left=0.12)
fig.canvas.draw()
full_pos = ax1.get_position()

# Panel letters at a common x in figure coordinates, left of the y-axis labels,
# near the top of each panel
x_fig = full_pos.x0 - .12
axes_list = [ax0, ax1, ax2, ax3, ax4]
letters = ['a', 'b', 'c', 'd', 'e']
for ax, letter in zip(axes_list, letters):
    pos = ax.get_position()
    y_fig = pos.y0 + pos.height * 0.98
    fig.text(x_fig, y_fig, letter, va='top', ha='left')

plt.show()