In [1]:
import numpy as np
import brainpy as bp
import brainpy.math as bm

from scipy.sparse import csr_matrix
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"  # specify which GPU(s) to be used
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
import torch

# visualization
import matplotlib as mpl
import matplotlib.pyplot as plt


from sbi.inference import (
    likelihood_estimator_based_potential,
    SNLE,
    prepare_for_sbi,
    simulate_for_sbi,
    VIPosterior,
)

from scipy import sparse
# sbi
from sbi.inference import SNPE, SNRE, SNLE, prepare_for_sbi, simulate_for_sbi
from sbi.utils.get_nn_models import posterior_nn, likelihood_nn, classifier_nn
from sbi import utils as utils
from sbi import analysis as analysis
from scipy.stats import kurtosis as kurt
from sbi.utils.user_input_checks import process_pytorch_prior, process_simulator

# # sbi
# from sbi import utils as utils
# from sbi import analysis as analysis
# from sbi.inference.base import infer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Body_Wall_muscle(bp.NeuGroup):
    """Conductance-based membrane model for a group of body-wall muscle cells.

    Per-cell state: membrane potential ``V``, gating variables ``m``, ``h``,
    ``n``, ``p``, ``kr``, ``p_slo2`` and a calcium variable ``Ca``. Each state
    variable gets its own integrator built with
    ``bp.odeint(..., method='exp_auto')``.

    ``V_th`` plays two roles: it is subtracted from ``V`` inside most kinetic
    and driving-force expressions, and ``update`` flags a spike when ``V``
    crosses it from below.

    Units are not stated in this file (presumably mV / ms -- TODO confirm).
    """
    def __init__(self, size, ECa= 60., gCa= 15.6, EK=-40., gK=34., EL=-24, gL=0.1,
                 V_th= 10., C= 22, p_max = 0.1, phi=1., phi_m = 1.2, gkr = 10., g_slo2 = 10. , g_Na = 0.01, ENa = 30, phi_n = 1.2, noise_factor = 0.01, **kwargs):
        """Store parameters, allocate state variables and build the integrators.

        Parameters
        ----------
        size
            Group size, forwarded to ``bp.NeuGroup``.
        ECa, EK, EL, ENa
            Potentials subtracted from ``V`` in the driving force of the
            calcium, potassium, leak and sodium currents respectively.
        gCa, gK, gL, g_Na, gkr, g_slo2, p_max
            Conductance scales of ``I_Ca``, ``I_K``, ``I_leak``, ``I_Na``,
            ``I_kr``, ``I_slo2`` and ``I_M``.
        V_th
            Voltage shift used in the kinetics and spike-detection level.
        C
            Divisor of the summed currents in ``dV``.
        phi, phi_m, phi_n
            Rate multipliers applied in ``dp_slo2``, ``dm`` and ``dn``.
        noise_factor
            Stored as ``self.noise``; see the review note in ``update``.
        **kwargs
            Forwarded to ``bp.NeuGroup.__init__``.
        """
        # providing the group "size" information
        super(Body_Wall_muscle, self).__init__(size=size, **kwargs)

        # initialize parameters
        self.ECa = ECa
        self.EK = EK
        self.EL = EL
        self.ENa = ENa
        self.gCa = gCa
        self.g_Na   = g_Na
        self.gK = gK
        self.gL = gL
        self.C = C
        self.p_max = p_max
        self.V_th  = V_th
        self.noise =  noise_factor 
        self.phi_m  = phi_m
        self.phi_n  = phi_n
        # alpha / beta are fixed constants (not constructor arguments); they are
        # only read by `dp_slo2`.
        self.alpha  = 43.
        self.beta   = 0.09
        self.g_slo2 = g_slo2
        self.gkr    = gkr
        self.phi    = phi

        # initialize variables
        # V starts at -30 plus a standard-normal draw per cell.
        self.V = bm.Variable(bm.random.randn(self.num) - 30.)
        self.m = bm.Variable(0.01 * bm.ones(self.num))
        self.h = bm.Variable(0.6 * bm.ones(self.num))
        self.n = bm.Variable(0.99 * bm.ones(self.num))
        self.p = bm.Variable(0.2 * bm.ones(self.num))
        self.kr = bm.Variable(0.0 * bm.ones(self.num))

        self.p_slo2 = bm.Variable(bm.zeros(self.num))
        self.Ca   = bm.Variable(bm.zeros(self.num))
        # Ica is not integrated; `update` overwrites it every step.
        self.Ica  = bm.Variable(bm.zeros(self.num))

        # `input` accumulates external drive and is zeroed at the end of `update`.
        self.input = bm.Variable(bm.zeros(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        # Large negative sentinel meaning "no spike recorded yet".
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

        # integral functions
        self.int_V = bp.odeint(f=self.dV, method='exp_auto')
        self.int_m = bp.odeint(f=self.dm, method='exp_auto')
        self.int_h = bp.odeint(f=self.dh, method='exp_auto')
        self.int_n = bp.odeint(f=self.dn, method='exp_auto')
        self.int_p = bp.odeint(f=self.dp, method='exp_auto')
        self.int_p_slo2 = bp.odeint(f=self.dp_slo2, method='exp_auto')
        self.int_Ca = bp.odeint(f=self.dCa, method='exp_auto')
        self.int_kr = bp.odeint(f=self.dkr, method='exp_auto')

    def dV(self, V, t, m, h, n, p, p_slo2, kr, Iext):
        """Return dV/dt: (external input minus the seven membrane currents) / C."""
        I_Ca = (self.gCa * m ** 2.0 * h) * (V - self.V_th - self.ECa)
        I_K = (self.gK * n ** 4.0) * (V - self.V_th - self.EK)
        I_M = (self.p_max * p) * (V - self.V_th - self.EK)
        # NOTE(review): I_slo2 and I_kr use (V - EK) without the V_th shift that
        # every other current applies -- confirm this is intentional.
        I_slo2 = (self.g_slo2 * self.m_slo2inf(V)**3 * p_slo2) * (V -  self.EK)
        # I_Na has no gating variable: constant conductance g_Na.
        I_Na = self.g_Na * (V - self.V_th - self.ENa)
        I_kr = self.gkr *(1-kr) * self.krinf(V) *  (V - self.EK)
        I_leak = self.gL * (V - self.V_th  - self.EL)
        dVdt = (- I_Ca  - I_K - I_Na - I_slo2 - I_leak - I_kr - I_M + Iext) / self.C
        return dVdt
    
    # Voltage-dependent steady-state curves shared by `dV` and `dkr`.
    # `krinf` applies the V_th shift; `m_slo2inf` uses a fixed midpoint of -33.4.
    krinf  = lambda self, V: 0.5 *(1+bm.tanh((V -  self.V_th + 42)/ 5.0))
    m_slo2inf = lambda self, V: 1/(1+bm.exp(-(V - (-33.4)) / 3.2))

    def dkr(self, kr, t, V):
        """Return dkr/dt: relaxation of ``kr`` towards ``krinf(V)`` with a fixed time constant of 62."""
        # krinf = 0.5 *(1+bm.tanh((V -  self.V_th + 42)/ 5.0))
        taumkr= 62
        dkrdt = (self.krinf(V)-kr)/taumkr
        return dkrdt

    def dp_slo2(self, p_slo2, t, Ca, V):
        """Return dp_slo2/dt, driven by ``Ca`` only (``V`` is accepted but not used).

        With C2 = alpha * Ca**2 and C3 = C2 + beta, the target value is C2 / C3
        and the rate is phi * C3.
        """
        C2 = self.alpha * bm.power(Ca, 2)
        C3 = C2 + self.beta
        return self.phi * (C2 / C3 - p_slo2) * C3

    def dCa(self, Ca, t, m, h, V):
        """Return dCa/dt: a source proportional to -ICa plus decay towards 0.001."""
        # Same expression as I_Ca in `dV`.
        ICa = (self.gCa * m ** 2.0 * h) * (V - self.V_th - self.ECa)
        return -0.15 * ICa * 1e-4 - 0.075 * (Ca - 0.001)

    def dn(self, n, t, V):
        """Return dn/dt: relaxation to ``ninf(V)`` with voltage-dependent ``tau_n``, scaled by ``phi_n``."""
        ninf = 0.5 * (bm.tanh((V - self.V_th +15.2)/36.22)+1)
        tau_n = 1.18+511.78/(1+bm.exp((V - self.V_th + 89.3)/21.92))
        dndt = self.phi_n * (ninf-n)/tau_n
        return dndt

    # Alternative `dm` kinetics, currently disabled:
    # def dm(self, m, t, V):
    #     tau_m = 61/(1+bm.exp((V - self.V_th + 81.2)/45.6)) + 22.39/(1+bm.exp(-(V - self.V_th -24.26)/22.26)) - 14.25 
    #     minf = -0.53/(1+bm.exp(-(V - self.V_th - 26)/6.4)) + 1.058/(1+bm.exp(-(V - self.V_th +8.75)/7.2655)) + 0.0095
    #     dmdt = self.phi_m * (minf-m)/tau_m
    #     return dmdt

    def dm(self, m, t, V):
        """Return dm/dt: relaxation to ``minf(V)`` with voltage-dependent ``tau_m``, scaled by ``phi_m``."""
        tau_m = 0.4 + .7 / (bm.exp(-(V + 5. - self.V_th) / 15.) +
                       bm.exp((V + 5. - self.V_th) / 15.))
        minf = 1. / (1 + bm.exp(-(V + 8. - self.V_th) / 8.6))
        dmdt = self.phi_m * (minf-m)/tau_m
        return dmdt

    def dh(self, h, t, V):
        """Return dh/dt: relaxation to ``hinf(V)`` with a fixed time constant of 24."""
        # Alternative hinf expressions are kept commented out around the active one.
        # hinf = 0.435/(1+bm.exp((V  - self.V_th + 10.38)/0.5554)) + 64.045/(1+bm.exp(-(V  - self.V_th -171.5)/30.8)) + 0.1
        hinf   = 0.42 / (1. + bm.exp((V + 11. - self.V_th) / 2.)) + 0.28
        # hinf  = (1.43 / (1 + bm.exp(-(V - self.V_th + 15 - 14.9) / 12)) + 0.14) * (5.96 / (1 + bm.exp((V  - self.V_th  + 15 + 20.5) / 8.1)) + 0.6 - 0.32)
        tau_h = 24
        dhdt = (hinf-h)/tau_h
        return dhdt

    def dp(self, p, t, V):
        """Return dp/dt: relaxation to ``pinf(V)`` with voltage-dependent ``tau_p``."""
        pinf = 1/(1+bm.exp(-(V- self.V_th +45)/10))
        tau_p = 4000/(3.38*bm.exp((V- self.V_th+45)/20)+bm.exp(-(V- self.V_th +45)/20))
        dpdt = (pinf-p)/tau_p
        return dpdt

    def update(self, tdi, x=None):
        """Advance every state variable by one step of size ``tdi.dt``.

        ``x`` is accepted but not used. The statement order matters: the new
        membrane potential is held in the local ``V`` and only written to
        ``self.V`` at the end, so all gating updates and the spike test see the
        previous step's voltage.
        """
        _t, _dt = tdi.t, tdi.dt
        # compute V, m, h, n
        # NOTE(review): `noise_add` is computed but never added to V or to the
        # input, so `noise_factor` currently has no effect on the dynamics.
        noise_add = self.noise * bm.random.randn(self.num) / bm.sqrt(_dt)
        # The accumulated external input is divided by 0.75 before use.
        V = self.int_V(self.V, _t, self.m, self.h, self.n, self.p, self.p_slo2, self.kr, self.input/0.75, dt=_dt)
        self.h.value = self.int_h(self.h, _t, self.V, dt=_dt)
        self.m.value = self.int_m(self.m, _t, self.V, dt=_dt)
        self.n.value = self.int_n(self.n, _t, self.V, dt=_dt)
        self.p.value = self.int_p(self.p, _t, self.V, dt=_dt)
        # p_slo2 is advanced with the Ca value from before this step's Ca update.
        self.p_slo2.value = self.int_p_slo2(self.p_slo2, _t, self.Ca, self.V, dt=_dt)
        # Ca and Ica use the m and h values already advanced above, with the old V.
        self.Ca.value = self.int_Ca(self.Ca, _t, self.m, self.h, self.V, dt=_dt)
        self.kr.value = self.int_kr(self.kr, _t, self.V, dt=_dt)
        self.Ica.value = (self.gCa * self.m ** 2.0 * self.h) * (self.V - self.V_th - self.ECa)

        # update the spiking state and the last spiking time
        # A spike is an upward crossing: old V below V_th, new V at or above it.
        self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
        self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)

        # update V
        self.V.value = V

        # reset the external input
        self.input[:] = 0.

In [3]:
def generate_neuron_connection_matrix(A):
    """Build a sparse block-diagonal connection matrix of independent neuron chains.

    Each row of ``A`` describes one group. Entry ``A[g, k]`` is the symmetric
    connection strength between neurons ``k`` and ``k + 1`` of group ``g``, so
    a row with ``c`` entries yields a group of ``c + 1`` neurons. Neurons in
    different groups are never connected.

    Parameters
    ----------
    A : array-like, shape (num_groups, group_size - 1)
        Intra-group connection strengths. A 1-D input is treated as a single
        group. Anything convertible with ``np.asarray`` is accepted.

    Returns
    -------
    scipy.sparse matrix in CSR format, shape
    (num_groups * group_size, num_groups * group_size).

    Raises
    ------
    ValueError
        If ``A`` is not 1-D or 2-D.
    """
    # Convert once up front so the per-row slices below are plain NumPy views
    # regardless of which array library the caller used.
    weights = np.asarray(A)
    if weights.ndim == 1:
        weights = weights[np.newaxis, :]
    if weights.ndim != 2:
        raise ValueError(
            "A must be a 1-D or 2-D array of connection strengths, got %d dimensions"
            % weights.ndim
        )

    group_size = weights.shape[1] + 1  # one more neuron than there are links

    # One tridiagonal (zero main diagonal) block per group: the same strengths
    # sit on the sub- and super-diagonal, which makes each block symmetric.
    group_connections = [
        sparse.diags([row, row], [-1, 1], shape=(group_size, group_size))
        for row in weights
    ]

    # Stack the per-group blocks along the diagonal.
    return sparse.block_diag(group_connections, format='csr')

In [4]:
# A = np.array([
#     [0.5, 0.3, 0.7],
#     [0.6, 0.2, 0.8]
# ])
# matrix = generate_neuron_connection_matrix(A)
# print(matrix)
# conn = bp.conn.SparseMatConn(matrix)
# print(conn.value)
# data = matrix.data
# print(data)

In [5]:
# Scratch check of SciPy CSR behaviour: build a random 5x3 boolean matrix,
# convert it to CSR and print its stored (row, col) entries.
# `conn_mat` is not referenced again in the visible notebook; `sparse_mat` is
# only used by the row-indexing cell that follows.
conn_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_)
sparse_mat = csr_matrix(conn_mat)
print(sparse_mat)

  (0, 0)	True
  (0, 1)	True
  (0, 2)	True
  (1, 2)	True
  (2, 2)	True
  (3, 0)	True
  (3, 1)	True
  (4, 0)	True
  (4, 2)	True


In [6]:
# Row-indexing a CSR matrix returns another sparse matrix (a 1x3 row here).
sparse_mat[0]

<1x3 sparse matrix of type '<class 'numpy.bool_'>'
	with 3 stored elements in Compressed Sparse Row format>

In [7]:
# Simulation length passed to `runner.run(...)` in `run_Net_model`
# (time unit is not stated in this notebook).
duration  = 1000.
# Neurons per group; `calculate_summary_statistics` averages statistics over
# consecutive blocks of this many neurons.
small_size = 6
# NOTE(review): `groups` is not referenced anywhere else in the visible notebook.
groups     = 1000.

In [8]:
class MuscleNet(bp.DynamicalSystemNS):
    """A population of `Body_Wall_muscle` cells coupled through gap junctions.

    Parameters
    ----------
    conn_matrix : scipy.sparse matrix
        Connection matrix between the cells. Its stored values
        (``conn_matrix.data``) are used as the per-connection ``g_max``, and
        its sparsity pattern defines the connectivity.
    net_size : int
        Number of cells in the population.
    **kwargs
        Accepted for interface compatibility; currently not forwarded anywhere.
    """
    def __init__(self, conn_matrix, net_size, **kwargs):
        super().__init__()
        # One conductance per stored matrix entry.
        # Assumes `SparseMatConn` enumerates connections in the same order as
        # the matrix's `.data` array -- TODO confirm against BrainPy.
        g_max = conn_matrix.data
        self.N = Body_Wall_muscle(size=net_size)
        self.N2N = bp.synapses.GapJunction(
            pre=self.N,
            post=self.N,
            conn=bp.conn.SparseMatConn(conn_matrix),
            comp_method='sparse',
            g_max=g_max,
        )

    def update(self):
        """Advance the cells first, then the gap junctions, by one step."""
        # The shared time `t` and step `dt` were previously loaded here but
        # never used; the sub-systems are simply called in order.
        self.N()
        self.N2N()

In [9]:
# A = np.array([
#     [0.5, 0.3, 0.7],
#     [0.6, 0.2, 0.8]
# ])
# connection_matrix = generate_neuron_connection_matrix(A)
# net_size = A.shape[0] * (A.shape[1] + 1)

# conn = bp.conn.SparseMatConn(connection_matrix)
# mat = conn.require("conn_mat")
# print(mat)
# muscle_net = MuscleNet(connection_matrix, net_size)

In [10]:
def run_Net_model(params):
    """Simulate a `MuscleNet` whose intra-group couplings are given by `params`.

    Parameters
    ----------
    params : array-like, shape (n_groups, n_couplings)
        One row per group; each row holds the coupling strengths between
        neighbouring cells, so every group has ``n_couplings + 1`` cells.

    Returns
    -------
    dict with keys
        ``t``      - monitored time points,
        ``spikes`` - monitored spike flags, transposed to (cell, time step),
        ``data``   - monitored membrane potentials, transposed to (cell, time step),
        ``dt``     - BrainPy's global integration step (``bm.get_dt()``).

    The simulation length comes from the module-level ``duration``.
    """
    params = bm.asarray(params)
    n_groups, n_couplings = params.shape
    group_size = n_couplings + 1
    net_size = n_groups * group_size

    connection_matrix = generate_neuron_connection_matrix(params)
    net = MuscleNet(conn_matrix=connection_matrix, net_size=net_size)

    # Every group receives the same input ramp from 30 down to 0 across its
    # cells. The ramp length must equal the group size so that the tiled
    # vector has exactly one entry per cell (it was hard-coded to 6 before).
    stimulus = bm.tile(bm.linspace(30, 0, group_size), n_groups)

    runner = bp.DSRunner(
        net,
        monitors=['N.spike', 'N.V'],
        inputs=['N.input', stimulus],
        progress_bar=False
    )
    runner.run(duration)

    return dict(
        t=runner.mon['ts'],
        spikes=runner.mon['N.spike'].T,
        data=runner.mon['N.V'].T,
        # Report the global step instead of a literal 0.1 so downstream timing
        # statistics stay correct if the global dt is changed.
        dt=bm.get_dt(),
    )

In [11]:
def time_to_first_spike(x):
    """Return the index of the first non-zero entry of `x`, or 0 if there is none.

    Note that 0 is returned both for "no spike" and for a spike at index 0.
    """
    hits = np.nonzero(x)[0]
    if hits.size == 0:
        return 0
    return hits[0]
def compute_mean_isi(neuron_spikes):
    """Mean gap, in samples, between consecutive spikes of one spike train.

    Returns 0. when the train has fewer than two spikes.
    """
    spike_idx = np.nonzero(neuron_spikes)[0]
    if spike_idx.size <= 1:
        return 0.
    return np.mean(np.diff(spike_idx))

def calculate_summary_statistics(x, group_size=None):
    """Reduce a simulation result to one row of summary statistics per group.

    Parameters
    ----------
    x : dict
        Must provide ``"data"`` (voltage traces), ``"spikes"`` (spike flags)
        and ``"dt"`` (time step). Both arrays are reduced along axis 1, i.e.
        they are treated as (cell, time step).
    group_size : int, optional
        Number of consecutive cells averaged into one output row. Defaults to
        the module-level ``small_size``.

    Returns
    -------
    numpy.ndarray, shape (n_cells // group_size, 6)
        Group means of, in column order: spike count, mean inter-spike
        interval * dt / 20, first-spike time * dt / 20, mean voltage, voltage
        standard deviation, and maximum voltage / 10.
    """
    if group_size is None:
        group_size = small_size

    v = np.array(x["data"])
    spikes = np.asarray(x["spikes"])
    dt = x["dt"]

    # Voltage statistics over the whole recorded trace (no stimulus window is
    # cut out).
    mean_v = np.mean(v, axis=1)
    std_v = np.std(v, axis=1)
    max_v = np.max(v, axis=1) / 10.0

    # Spike statistics; the index-based measures are converted to time via dt.
    spike_counts = spikes.sum(axis=1)
    mean_isi_values = np.apply_along_axis(compute_mean_isi, axis=1, arr=spikes) * dt
    first_spike_times = np.apply_along_axis(time_to_first_spike, axis=1, arr=spikes) * dt

    per_cell = np.column_stack((
        spike_counts,
        mean_isi_values / 20.,
        first_spike_times / 20.,
        mean_v,
        std_v,
        max_v,
    ))

    # Average the per-cell rows within each block of `group_size` cells.
    return per_cell.reshape(-1, group_size, per_cell.shape[1]).mean(axis=1)

def simulation_wrapper(params):
    """
    Simulate the network for the coupling strengths in `params` and return
    the resulting summary statistics as a float32 `torch.Tensor`.
    """
    simulated = run_Net_model(params)
    stats = calculate_summary_statistics(simulated)
    return torch.as_tensor(stats).to(torch.float32)

In [12]:
# prior_min = [5. , 5.,  10., 8.,  5.]
# prior_max = [20.,  20.0, 18., 16., 14.,]
# prior_min = [10. , 10.,  10., 8.,  5.]
# prior_max = [20.,  20.0, 16., 15., 15.,]

# Bounds of the uniform box prior: one (low, high) pair per coupling strength.
prior_min = [10.0, 10.0, 10.0, 8.0, 5.0]
prior_max = [25.0, 23.0, 18.0, 15.0, 15.0]

# Earlier three-parameter bounds, kept for reference:
# prior_min = [10.5 , 20.,   1e-4]
# prior_max = [50.0, 55.0,   30.,]
prior = utils.torchutils.BoxUniform(
    low=torch.as_tensor(prior_min),
    high=torch.as_tensor(prior_max),
)

In [13]:
# true_params = np.array([[15., 14., 13., 12.,11.]])
# Previously tried ground-truth values:
# true_params = np.array([[19.8, 37., 22., 10., 10.]])
# true_params = np.array([[15.6, 34., 10.]])

# Ground-truth couplings for the synthetic observation: a single group (one row).
true_params = np.array([[20.0, 17.0, 15.0, 11.0, 9.0]])
len(true_params)

1

In [14]:
# Simulate samples from the prior distribution
# Multi-round inference settings: the first round samples from the prior,
# later rounds sample from the most recent posterior.
num_rounds, posteriors, proposal = 3, [], prior

# Synthetic observation: simulate the ground-truth parameters and reduce the
# result to summary statistics, stored as a float32 tensor.
true_data = run_Net_model(true_params)
xo = torch.tensor(calculate_summary_statistics(true_data), dtype=torch.float32)

In [15]:
# Observed summary statistics: one row per group, six statistics per row
# (column order as produced by `calculate_summary_statistics`).
print(xo)

tensor([[  4.0000,   5.0561,  34.1375, -25.1901,  10.7755,   3.0057]])


In [16]:
# Sanity check: raw trace shape and observation shape.
print(true_data['data'].shape)
print(xo.shape)

# Number of parameter sets simulated per round. All of them are simulated in
# a single call: `simulation_wrapper` treats every row of `theta` as one group
# of the network.
num_simulations = 8_000

inference = SNLE(prior)
for _ in range(num_rounds):
    # 1) Draw parameters from the current proposal and simulate them.
    theta = proposal.sample((num_simulations,))
    stats = simulation_wrapper(theta)
    print('done.')

    # 2) Train the likelihood estimator on all simulations appended so far.
    #    The progress message is printed before training starts; it used to
    #    appear only after training had already finished.
    print('Training inference network... ')
    likelihood_estimator = inference.append_simulations(
        theta, stats,
    ).train()

    # 3) Fit a variational posterior for the observation `xo` and use it as
    #    the proposal of the next round.
    potential_fn, theta_transform = likelihood_estimator_based_potential(
        likelihood_estimator, prior, xo
    )
    posterior = VIPosterior(
        potential_fn, prior, "maf", theta_transform, vi_method="fKL",
    ).train()
    posteriors.append(posterior)
    proposal = posterior

(6, 10000)
torch.Size([1, 6])
done.
 Neural network successfully converged after 158 epochs.Training inference network... 


Loss: 30.84Std: 0.06:  40%|███▉      | 790/2000 [00:15<00:24, 49.60it/s]              



Converged with loss: 30.84
Quality Score: 0.264 	 Good: Smaller than 0.5  Bad: Larger than 1.0 	         NOTE: Less sensitive to mode collapse.
done.
 Neural network successfully converged after 127 epochs.Training inference network... 


Loss: 33.59Std: 0.2:  24%|██▎       | 473/2000 [00:09<00:32, 47.56it/s]               



Converged with loss: 33.59
Quality Score: 0.181 	 Good: Smaller than 0.5  Bad: Larger than 1.0 	         NOTE: Less sensitive to mode collapse.
done.
 Neural network successfully converged after 33 epochs.Training inference network... 


Loss: 32.34Std: 0.13:  72%|███████▏  | 1445/2000 [00:29<00:11, 49.14it/s]             



Converged with loss: 32.34
Quality Score: 0.282 	 Good: Smaller than 0.5  Bad: Larger than 1.0 	         NOTE: Less sensitive to mode collapse.


In [18]:
# np.savez('Muscle_Net.npz', samples=samples, true_params=true_params, true_data=true_data, xo=xo)
# Draw samples from the final-round posterior and store them with the ground truth.
samples = posterior.sample((3000,), x=xo, show_progress_bars=False)

# `np.savez` does not create missing directories, so make sure the target exists.
os.makedirs('data', exist_ok=True)
np.savez(
    'data/Muscle_Net_4.npz',
    # Convert tensors explicitly so saving does not rely on implicit
    # tensor -> ndarray conversion.
    samples=samples.detach().cpu().numpy(),
    true_params=true_params,
    # NOTE: `true_data` is a dict, so it is stored as a pickled object array;
    # loading it back requires `np.load(..., allow_pickle=True)`.
    true_data=true_data,
    xo=xo.detach().cpu().numpy(),
)