In [1]:
import numpy as np
from scipy.integrate import odeint
import plotly.graph_objects as go
from typing import Callable

In [2]:
class Synapse:
    """Ligand-gated synapse driven by a random presynaptic spike train.

    Every time step ``L_out`` ligand is removed and, on steps with a
    presynaptic spike, ``L_in`` ligand is added; the amount is clipped to
    ``[0, V]``.  With ``l = L / V`` the bound fraction of receptors is
    ``P = (l / K_d) / (1 + l / K_d)`` and the synaptic current is
    ``N * P * (V_m - V_eq) * gamma``.

    Parameters
    ----------
    L_in, L_out : float
        Ligand added per spike and removed per step.
    V : float
        Upper bound on the ligand amount and divisor turning it into ``l``.
    K_d, gamma, N, V_eq : float, optional
        Binding constant, current scale factor (presumably per-receptor
        conductance), receptor count (presumably) and reversal potential.
        ``K_d``, ``gamma`` and ``V_eq`` fall back to the ``mode`` preset
        when left as ``None``.
    spiking_par : sequence, optional
        Positional arguments for ``spiking_dist`` (``None`` means none).
    spiking_dist : callable, optional
        ``f(*spiking_par, size=n)`` returning ``n`` inter-spike intervals;
        falls back to the ``mode`` preset when ``None``.
    mode : str
        'A': K_d=128e-9, gamma=8e-3, V_eq=-65, ``np.random.poisson``;
        'B': K_d=500e-9, gamma=10e-3, V_eq=55, ``np.random.randint``.
        Any other value applies no preset.
    max_t : int
        Number of unit time steps to simulate.
    """

    # Presets applied by mode_init to any of these attributes still None.
    _MODE_DEFAULTS = {
        'A': dict(K_d=128 * 1e-9, gamma=8 * 1e-3, V_eq=-65,
                  spiking_dist=np.random.poisson),
        'B': dict(K_d=500 * 1e-9, gamma=10 * 1e-3, V_eq=55,
                  spiking_dist=np.random.randint),
    }

    def __init__(self, L_in, L_out, V, 
                 K_d=None, gamma=None, N=None, V_eq=None, 
                 spiking_par=None, spiking_dist=None,
                 mode='A', max_t=60000):
        self.L_in = L_in
        self.L_out = L_out
        self.V = V
        self.K_d = K_d
        self.gamma = gamma
        self.N = N
        self.V_eq = V_eq
        self.mode = mode
        self.max_t = max_t
        self.spiking_par = spiking_par
        self.spiking_dist = spiking_dist
        self.mode_init()
    
    def mode_init(self):
        """Fill unset parameters from the preset of ``self.mode``.

        Only attributes that are still ``None`` are filled, so values passed
        explicitly to the constructor are respected (previously they were
        silently overwritten).  Unknown modes leave everything untouched.
        """
        for name, value in self._MODE_DEFAULTS.get(self.mode, {}).items():
            if getattr(self, name) is None:
                setattr(self, name, value)
    
    def simulate(self):
        """Draw a spike train, integrate the ligand amount and build the current.

        Sets ``times``, ``spiking_times``, ``times2ind``, ``Ls`` and
        ``P_bounds`` on the instance.

        Returns
        -------
        callable
            ``I_synapse(t, V_m)``: synaptic current at (possibly fractional)
            time ``t``.  ``P_bound`` is held constant within each unit step;
            ``t`` outside the simulated grid is clamped to its ends because
            ODE solvers such as ``odeint`` may evaluate slightly past the
            last requested time.
        """
        self.times = np.arange(0, self.max_t, 1)
        spiking_par = () if self.spiking_par is None else self.spiking_par
        spiking_intervals = self.spiking_dist(*spiking_par, size=self.max_t)
        spiking_times = np.cumsum(spiking_intervals)
        self.spiking_times = spiking_times[spiking_times < self.max_t]

        # Kept for backward compatibility; since times == arange(max_t), index == time.
        self.times2ind = dict(zip(self.times, range(len(self.times))))

        # One vectorised membership test instead of an O(n_spikes) `t in array` per step.
        is_spiking = np.isin(self.times, self.spiking_times)

        # Ligand amount: add L_in on spike steps, remove L_out every step, clip to [0, V].
        self.Ls = np.zeros((len(self.times), ))
        L_prev = 0
        for i, spike in enumerate(is_spiking):
            L_prev = max(0, min(L_prev + self.L_in * spike - self.L_out, self.V))
            self.Ls[i] = L_prev

        l = self.Ls / self.V
        l_div_K_d = l / self.K_d
        self.P_bounds = l_div_K_d / (1 + l_div_K_d)

        last_ind = len(self.times) - 1

        def I_synapse(t, V_m):
            # int() floors for t >= 0; clamping avoids a KeyError/IndexError when
            # the integrator overshoots the last grid point.
            t_ind = min(max(int(t), 0), last_ind)
            return self.N * self.P_bounds[t_ind] * (V_m - self.V_eq) * self.gamma
        return I_synapse


In [28]:
class HHH:
    """Hodgkin-Huxley neuron driven by two stochastic ligand-gated synapses.

    The membrane obeys the classic Hodgkin-Huxley equations (gating rates
    written for a resting potential of -65) with K+, Na+ and leak currents,
    plus the currents of ``synapse_A`` (``Synapse`` mode 'A') and
    ``synapse_B`` (``Synapse`` mode 'B').  The state vector is
    ``y = [V_m, n, m, h]``.

    Parameters
    ----------
    L_in, L_out, V : float
        Ligand parameters shared by both synapses.
    N_A, N_B : float
        ``N`` of synapses A and B.
    spiking_par_A, spiking_par_B : sequence
        Inter-spike interval distribution parameters of synapses A and B.
    max_t : int
        Number of unit time steps to simulate (plots label them ms).
    g_L, g_K, g_Na : float
        Maximal conductances of the leak, K+ and Na+ currents.
    E_L, E_K, E_Na : float
        Reversal potentials of the leak, K+ and Na+ currents.
    C_m : float
        Membrane capacitance.
    V_0 : float
        Initial membrane potential; gates start at their steady state there.
    """

    # Font shared by every 2-D figure.
    _FONT = dict(family="Courier New, monospace", size=18)

    def __init__(self, L_in, L_out, V, 
                 N_A, N_B, spiking_par_A, spiking_par_B,
                 max_t=60000,
                 g_L=3e-1, g_K=36, g_Na=120, 
                 E_L=-54.387, E_K=-77, E_Na=50, 
                 C_m=1, V_0=-65):
        self.g_L = g_L
        self.g_K = g_K
        self.g_Na = g_Na
        self.E_L = E_L
        self.E_K = E_K
        self.E_Na = E_Na
        self.C_m = C_m
        self.max_t = max_t
        
        self.synapse_A = Synapse(L_in=L_in, L_out=L_out, V=V, N=N_A,
                                 spiking_par=spiking_par_A, max_t=max_t, mode='A')
        self.synapse_B = Synapse(L_in=L_in, L_out=L_out, V=V, N=N_B,
                                 spiking_par=spiking_par_B, max_t=max_t, mode='B')

        # Start at V_0 with every gate at steady state x_inf = alpha / (alpha + beta).
        self.y_0 = [V_0] + [self.alpha(V_0, gate) /
                            (self.alpha(V_0, gate) + self.beta(V_0, gate))
                            for gate in ('n', 'm', 'h')]

    @staticmethod
    def _exp_ratio(x, scale=10.0):
        """Return ``x / (1 - exp(-x / scale))``, using its limit ``scale`` at x == 0.

        ``expm1`` keeps the denominator accurate near 0 and the removable
        0/0 singularity is replaced by its limit instead of producing NaN.
        """
        x = np.asarray(x, dtype=float)
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = x / -np.expm1(-x / scale)
        return np.where(x == 0, scale, ratio)

    @staticmethod
    def alpha(V_m, mode):
        """Forward rate of gate ``mode`` ('n', 'm' or 'h'): drives the gate towards 1.

        Raises ValueError for any other ``mode``.
        """
        if mode == 'n':
            return 1e-2 * HHH._exp_ratio(V_m + 55)
        elif mode == 'm':
            return 1e-1 * HHH._exp_ratio(V_m + 40)
        elif mode == 'h':
            return 7e-2 * np.exp(-5e-2 * (V_m + 65))
        else:
            raise ValueError("mode should be 'm', 'n' or 'h'")
    
    @staticmethod
    def beta(V_m, mode):
        """Backward rate of gate ``mode`` ('n', 'm' or 'h'): drives the gate towards 0.

        Raises ValueError for any other ``mode``.
        """
        if mode == 'n':
            # Classic HH: 0.125 * exp(-(V + 65) / 80); the exponent must be negative.
            return 0.125 * np.exp(-0.0125 * (V_m + 65))
        elif mode == 'm':
            return 4 * np.exp(-0.0556 * (V_m + 65))
        elif mode == 'h':
            return 1 / (1 + np.exp(-0.1 * (V_m + 35)))
        else:
            raise ValueError("mode should be 'm', 'n' or 'h'")

    def I_K(self, n, V_m):
        """Potassium current g_K * n^4 * (V_m - E_K)."""
        return self.g_K * n ** 4 * (V_m - self.E_K)
    
    def I_Na(self, m, h, V_m):
        """Sodium current g_Na * m^3 * h * (V_m - E_Na)."""
        return self.g_Na * m ** 3 * h * (V_m - self.E_Na)

    def I_L(self, V_m):
        """Leak current g_L * (V_m - E_L)."""
        return self.g_L * (V_m - self.E_L)
    
    def ders(self, y, t):
        """Right-hand side dy/dt for ``odeint`` with ``y = [V_m, n, m, h]``.

        Requires ``fit`` to have set ``I_A`` / ``I_B`` (callables ``(t, V_m)``).
        """
        V_m, n, m, h = y
        d_gates = [self.alpha(V_m, gate) * (1 - x) - self.beta(V_m, gate) * x
                   for gate, x in (('n', n), ('m', m), ('h', h))]
        # The leak term was previously missing although g_L / E_L were defined.
        I_ion = self.I_K(n, V_m) + self.I_Na(m, h, V_m) + self.I_L(V_m)
        I_syn = self.I_A(t, V_m) + self.I_B(t, V_m)
        return [-(I_ion + I_syn) / self.C_m] + d_gates
    
    def fit(self):
        """Simulate both synapses, then integrate the membrane equations."""
        self.I_A = self.synapse_A.simulate()
        self.I_B = self.synapse_B.simulate()
        self.t = self.synapse_A.times
        self.solution = odeint(self.ders, self.y_0, self.t)

    def _show_line(self, x, y, xaxis_title, yaxis_title):
        """Show one 2-D line plot in the notebook's shared style."""
        fig = go.Figure(data=go.Scatter(x=x, y=y))
        fig.update_layout(xaxis_title=xaxis_title, yaxis_title=yaxis_title,
                          font=dict(self._FONT))
        fig.show()
    
    def plot_V(self, min_ind, max_ind):
        """Plot membrane potential for time indices [min_ind, max_ind)."""
        sl = slice(min_ind, max_ind)
        self._show_line(self.t[sl], self.solution[sl, 0], "t (ms)", "V")
    
    def plot_n_m_h(self, min_ind, max_ind):
        """Plot the n, m and h gates on one figure for indices [min_ind, max_ind)."""
        sl = slice(min_ind, max_ind)
        fig = go.Figure()
        for col, gate in ((1, "n"), (2, "m"), (3, "h")):
            fig.add_trace(go.Scatter(x=self.t[sl], y=self.solution[sl, col], name=gate))
        fig.update_layout(xaxis_title="t (ms)", yaxis_title="value",
                          font=dict(self._FONT))
        fig.show()
    
    def plot_I(self, min_ind, max_ind):
        """Plot I_K, I_Na (+ synapse B) and I_Cl (= synapse A) for [min_ind, max_ind)."""
        sl = slice(min_ind, max_ind)
        t = self.t[sl]
        V_m, n, m, h = self.solution[sl].T
        self._show_line(t, self.I_K(n, V_m), "t (ms)", "I_K")
        # Synapse B (mode 'B', V_eq=55) is added to the sodium current.
        I_Na = self.I_Na(m, h, V_m) + np.vectorize(self.I_B)(t, V_m)
        self._show_line(t, I_Na, "t (ms)", "I_Na")
        # Synapse A (mode 'A', V_eq=-65) is shown as the chloride current.
        self._show_line(t, np.vectorize(self.I_A)(t, V_m), "t (ms)", "I_Cl")

    def plot_phases(self, min_ind, max_ind):
        """3-D trajectories (t, V, gate) for n, m and h over [min_ind, max_ind)."""
        sl = slice(min_ind, max_ind)
        for col, gate, color in ((1, 'n', 'blue'), (2, 'm', 'red'), (3, 'h', 'green')):
            data = go.Scatter3d(x=self.t[sl], y=self.solution[sl, 0],
                                z=self.solution[sl, col], mode='lines',
                                line=dict(width=3, color=color))
            layout = go.Layout(
                scene=dict(
                    xaxis=dict(title='t (ms)'),
                    yaxis=dict(title='V'),
                    zaxis=dict(title=gate)
                ), title=f"V and {gate}")
            go.Figure(data=data, layout=layout).show()

    def spike_frequencies(self, treshold=0, steps_per_second=1000):
        """Count spike onsets (upward crossings of ``treshold``) per full second.

        A sample counts as an onset when V > treshold and the previous sample
        was <= treshold (the first sample counts if it is above).  Assumes one
        time step is 1 ms, hence ``steps_per_second=1000``; a trailing partial
        second is dropped instead of raising IndexError or under-counting.
        """
        V = np.asarray(self.solution[:, 0])
        onsets = V > treshold
        onsets[1:] &= V[:-1] <= treshold
        n_seconds = len(V) // steps_per_second
        usable = onsets[:n_seconds * steps_per_second]
        return usable.reshape(n_seconds, steps_per_second).sum(axis=1)
    
    def plot_frequency(self, treshold=0):
        """Plot the spike count in each 1-second window (i.e. frequency in Hz)."""
        freqs = self.spike_frequencies(treshold)
        self._show_line(list(range(len(freqs))), freqs, "t (s)", "Freq (Hz)")

In [43]:
# Seed NumPy's global RNG: both synapses draw their inter-spike intervals from
# it (np.random.poisson / np.random.randint), so without a seed every re-run
# produces different spike trains and different figures.
np.random.seed(42)
# spiking_par_A=[100]: Poisson inter-spike intervals with mean 100 steps;
# spiking_par_B=[0, 200]: uniform integer intervals in [0, 200).
model = HHH(L_in=500, L_out=100, V=500, N_A=100, N_B=100,
            spiking_par_A=[100], spiking_par_B=[0, 200], max_t=60000)

model.fit()

In [44]:
# Spike onsets (upward crossings of 0) counted per 1000-step window.
model.plot_frequency(treshold=0)

In [47]:
# Membrane potential over the first 3000 time steps.
model.plot_V(min_ind=0, max_ind=3000)

In [48]:
# Zoom into the first 1000 time steps.
model.plot_V(min_ind=0, max_ind=1000)

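Нейрон нигде не молчит дольше 10 с подряд (см. график частоты спайков выше). *(The neuron is never silent for more than 10 s in a row — see the spike-frequency plot above.)*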

In [49]:
# Gating variables n, m, h over the first 3000 time steps.
model.plot_n_m_h(min_ind=0, max_ind=3000)

In [50]:
# 3-D (t, V, gate) trajectories for each gate over the first 3000 steps.
model.plot_phases(min_ind=0, max_ind=3000)

In [51]:
# Ionic and synaptic currents over the first 3000 time steps.
model.plot_I(min_ind=0, max_ind=3000)