In [1]:
###########IMPORTS############

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
###########INTEGRATOR FRAMEWORK############

#1 Constraint Checks

def _check_input_types(t, y0): # Ensure input is Correct
    if not (y0.dtype.is_floating or y0.dtype.is_complex):
        raise TypeError('`y0` must have a floating point or complex floating point dtype')
    if not t.dtype.is_floating:
        raise TypeError('`t` must have a floating point dtype')
        
def _assert_increasing(t):
    """Return a control-dependency context guarding on `t` being strictly increasing.

    Ops created inside the returned context depend on an Assert op that fails
    at run time if any t[i+1] <= t[i] (i.e. `t` is not monotonic increasing).
    """
    strictly_increasing = math_ops.reduce_all(t[1:] > t[:-1])
    check = control_flow_ops.Assert(strictly_increasing, ['`t` must be monotonic increasing'])
    return ops.control_dependencies([check])

#2 Integrator Class

class _Integrator():
    """Fixed-step 4th-order Runge-Kutta integrator with firing-time bookkeeping.

    When the number of tracked neurons is positive, the last `n_fire` entries of
    the state vector are treated as "last firing time" slots for the first
    `n_fire` entries. Whenever one of those first entries goes from below
    `fire_threshold` to at/above it during a step, its slot is set to the
    step's start time. Otherwise the slot keeps its previous value.

    Args:
        n_fire: number of tracked entries. Defaults to the module-level global
            `n_`, read when the scan function is built (original behaviour).
        fire_threshold: crossing threshold(s) compared elementwise against the
            first `n_fire` entries. Defaults to the module-level global `F_b`.
    """

    def __init__(self, n_fire=None, fire_threshold=None):
        self._n_fire = n_fire
        self._fire_threshold = fire_threshold

    def integrate(self, evol_func, y0, time_grid):
        """Integrate `evol_func(y, t)` over `time_grid` starting from `y0`.

        Returns a tensor with one row per entry of `time_grid`; row 0 is `y0`.
        """
        time_delta_grid = time_grid[1:] - time_grid[:-1]
        scan_func = self._make_scan_func(evol_func)
        y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid), y0)
        return array_ops.concat([[y0], y_grid], axis=0)

    def _make_scan_func(self, evol_func):
        """Build the per-step function consumed by `functional_ops.scan`."""
        # Fall back to the notebook globals so existing `_Integrator()` calls keep working.
        n_fire = n_ if self._n_fire is None else self._n_fire
        threshold = F_b if self._fire_threshold is None else self._fire_threshold

        def scan_func(y, t_dt):
            t, dt = t_dt
            dy = math_ops.cast(self._step_func(evol_func, t, dt, y), dtype=y.dtype)
            out = y + dy
            if n_fire <= 0:
                # No firing-time slots: plain RK4 step (also avoids y[-0:] == y).
                return out

            # Upward threshold crossing of the first n_fire entries within this step.
            crossed = tf.logical_and(tf.less(y[:n_fire], threshold),
                                     tf.greater_equal(out[:n_fire], threshold))
            fire_t = y[-n_fire:]
            # Record the step's start time exactly (rather than fire_t + (t - fire_t)).
            t_fill = tf.fill(tf.shape(fire_t), math_ops.cast(t, fire_t.dtype))
            new_fire_t = tf.where(crossed, t_fill, fire_t)
            # The RK4 result for the firing-time slots is replaced by the bookkeeping value.
            return tf.concat([out[:-n_fire], new_fire_t], 0)

        return scan_func

    def _step_func(self, evol_func, t, dt, y):
        """Classic RK4 increment: dt/6 * (k1 + 2*k2 + 2*k3 + k4)."""
        k1 = evol_func(y, t)
        half_step = t + dt / 2
        dt_cast = math_ops.cast(dt, y.dtype)

        k2 = evol_func(y + dt_cast * k1 / 2, half_step)
        k3 = evol_func(y + dt_cast * k2 / 2, half_step)
        k4 = evol_func(y + dt_cast * k3, t + dt)
        return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6)

#3 Integral Caller

def odeint_fixed(func, y0, t):
    """Integrate dy/dt = func(y, t) with fixed-step RK4 on the grid `t`.

    `t` is converted to a tensor (float64 preferred) and must be strictly
    increasing (checked at run time). Returns a tensor with one state row per
    time point, the first row being `y0`.
    """
    time_grid = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
    initial_state = ops.convert_to_tensor(y0, name='y0')
    _check_input_types(time_grid, initial_state)

    guard = _assert_increasing(time_grid)
    with guard:
        integrator = _Integrator()
        return integrator.integrate(func, initial_state, time_grid)

In [197]:
#### ALL EQUATIONS AND PARAMETERS TAKEN FROM Tiffany Kee et al. ####
#### Values for PNs are used here ####

T = 26                               # Temperature

n_n = 3                              # number of neurons (must match len(inp) and the synapse matrices below)

p_n = 2                              # number of PNs
l_n = 2                              # number of LNs
# NOTE(review): p_n and l_n are not used in this notebook and p_n + l_n != n_n;
# confirm the intended network before relying on them.

t = np.arange(0.0, 500, 0.01)        # simulation time grid (0 to 500, step 0.01)

C_m  = [1.0]*n_n                     # n_n x 1 vector for capacitance

g_Na = [100.0]*n_n                   # n_n x 1 vector for sodium conductance
g_K  = [10.0]*n_n                    # n_n x 1 vector for potassium conductance
g_L  = [0.15]*n_n                    # n_n x 1 vector for leak conductance
g_KL  = [0.05]*n_n                   # n_n x 1 vector for K leak conductance
g_A  = [10.0]*n_n                    # n_n x 1 vector for Transient K conductance

E_Na = [50.0]*n_n                    # n_n x 1 vector for Na Potential
E_K  = [-95.0]*n_n                   # n_n x 1 vector for K Potential
E_L  = [-55.0]*n_n                   # n_n x 1 vector for Leak Potential
E_KL  = [-95.0]*n_n                  # n_n x 1 vector for K Leak Potential
E_A  = [-95.0]*n_n                   # n_n x 1 vector for Transient K Potential
F_b = [0.0]*n_n                      # n_n x 1 vector for fire potential (spike-detection threshold)

inp = [0.0,15.0,0.0]                 # External Current Inputs, one per neuron
assert len(inp) == n_n, "`inp` needs one entry per neuron"

# ACETYLCHOLINE
# ach_mat[i, j] == 1 marks a synapse onto neuron i (row) whose transmitter
# release is driven by the firing of neuron j (column).

ach_mat = np.array([[0.0,0.0,0.0],
                    [0.0,0.0,0.0],
                    [0.0,1.0,0.0]])
assert ach_mat.shape == (n_n, n_n), "`ach_mat` must be n_n x n_n"

# Count entries equal to 1: the same mask used when scattering/gathering synapse states.
n_syn_ach = int(np.count_nonzero(ach_mat == 1))  # Number of Acetylcholine (Ach) Synapses
alp_ach = [10.0]*n_syn_ach           # Alpha for Ach Synapse
bet_ach = [0.2]*n_syn_ach            # Beta for Ach Synapse
t_max = 0.3                          # Duration of the transmitter pulse after a presynaptic spike
t_delay = 0                          # Axonal Transmission Delay
A = [0.5]*n_n                        # Transmitter pulse amplitude (Synaptic Response Strength)
g_ach = [1.0]*n_n                    # Ach Conductance
E_ach = [0.0]*n_n                    # Ach Potential

# FAST GABA
# fgaba_mat[i, j] == 1 marks a synapse onto neuron i (row) whose release
# depends on the membrane potential of neuron j (column).

fgaba_mat = np.array([[0.0,0.0,0.0],
                      [0.0,0.0,1.0],
                      [0.0,0.0,0.0]])
assert fgaba_mat.shape == (n_n, n_n), "`fgaba_mat` must be n_n x n_n"

n_syn_fgaba = int(np.count_nonzero(fgaba_mat == 1))  # Number of Fast GABA (fGABA) Synapses
alp_fgaba = [10.0]*n_syn_fgaba       # Alpha for fGABA Synapse
bet_fgaba = [0.16]*n_syn_fgaba       # Beta for fGABA Synapse
V0 = [-20.0]*n_n                     # Half-activation potential of the fGABA release sigmoid
sigma = [1.5]*n_n                    # Slope of the fGABA release sigmoid (same units as V)
g_fgaba = [0.8]*n_n                  # fGABA Conductance
E_fgaba = [-70.0]*n_n                # fGABA Potential

# Temperature factor dividing the gating time constants.
# NOTE(review): Q10 factors are usually Q10**((T - T_ref)/10); with (22 - T) here,
# phi < 1 for T > 22, i.e. slower kinetics when warmer — confirm against Kee et al.
phi = 3.0**((22-T)/10)

def Na_prop(V):
    """Na+ channel gating kinetics.

    Rate functions are evaluated at V_ = V + 50 (V shifted by -50).

    Returns:
        (m_inf, tau_m, h_inf, tau_h): steady states alpha/(alpha+beta) and time
        constants 1/((alpha+beta)*phi) for activation m and inactivation h.
    """
    V_ = V-(-50)

    alpha_m = 0.32*(13.0 - V_)/(tf.exp((13.0 - V_)/4.0) - 1.0)
    beta_m = 0.28*(V_ - 40.0)/(tf.exp((V_ - 40.0)/5.0) - 1.0)

    alpha_h = 0.128*tf.exp((17.0 - V_)/18.0)
    beta_h = 4.0/(tf.exp((40.0 - V_)/5.0) + 1.0)

    t_m = 1.0/((alpha_m+beta_m)*phi)
    t_h = 1.0/((alpha_h+beta_h)*phi)

    # The temperature factor phi scales only the time constants. The previous
    # alpha*tau form equals m_inf/phi, which exceeds 1 whenever phi < 1.
    m_inf = alpha_m/(alpha_m+beta_m)
    h_inf = alpha_h/(alpha_h+beta_h)

    return m_inf, t_m, h_inf, t_h


def K_prop(V):
    """Delayed-rectifier K+ activation kinetics.

    Rate functions are evaluated at V_ = V + 50 (V shifted by -50).

    Returns:
        (n_inf, tau_n): steady state alpha/(alpha+beta) and time constant
        1/((alpha+beta)*phi).
    """
    V_ = V-(-50)

    alpha_n = 0.02*(15.0 - V_)/(tf.exp((15.0 - V_)/5.0) - 1.0)
    beta_n = 0.5*tf.exp((10.0 - V_)/40.0)

    t_n = 1.0/((alpha_n+beta_n)*phi)
    # phi belongs only in the time constant; alpha*tau would give n_inf/phi.
    n_inf = alpha_n/(alpha_n+beta_n)

    return n_inf, t_n


def m_a_inf(V):
    """Steady-state activation of the transient K (A) current: 1/(1 + exp(-(V+60)/8.5))."""
    exponent = -(V+60.0)/8.5
    return 1/(1+tf.exp(exponent))

def h_a_inf(V):
    """Steady-state inactivation of the transient K (A) current: 1/(1 + exp((V+78)/6))."""
    exponent = (V+78.0)/6.0
    return 1/(1+tf.exp(exponent))

def tau_m_a(V):
    """Time constant of A-current activation, divided by the temperature factor phi.

    NOTE(review): the Huguenard & McCormick (1992) A-current form is commonly
    written 1/(exp(...) + exp(...)) + 0.37, with 0.37 added outside the
    reciprocal. Here 0.37 sits inside the denominator; confirm against Kee et al.
    """
    denominator = tf.exp((V+35.82)/19.69) + tf.exp(-(V+79.69)/12.7) + 0.37
    return 1/denominator / phi
    
def tau_h_a(V):
    """Time constant of A-current inactivation, divided by phi.

    For V < -63 it follows 1/(exp((V+46.05)/5) + exp(-(V+238.4)/37.45)).
    Otherwise it is the constant 19.0, broadcast to the shape of V.
    """
    below_cutoff = tf.less(V,-63)
    voltage_dependent = 1/(tf.exp((V+46.05)/5) + tf.exp(-(V+238.4)/37.45)) / phi
    plateau = 19.0 / phi * tf.ones(tf.shape(V),dtype=V.dtype)
    return tf.where(below_cutoff, voltage_dependent, plateau)


def m_Ca_inf(V):
    """Steady-state Ca activation: 1/(1 + exp(-(V+20)/6.5)). Not referenced by dAdt here."""
    exponent = -(V+20.0)/6.5
    return 1/(1+tf.exp(exponent))

def h_Ca_inf(V):
    """Steady-state Ca inactivation: 1/(1 + exp((V+25)/12)). Not referenced by dAdt here."""
    exponent = (V+25.0)/12
    return 1/(1+tf.exp(exponent))

def tau_m_Ca(V):
    """Ca activation time constant: the constant 1.5 (`V` is accepted and ignored)."""
    constant_tau = 1.5
    return constant_tau
    
def tau_h_Ca(V):
    """Ca inactivation time constant: 0.3*exp((V-40)/13) + 0.002*exp((60-V)/29)."""
    first_term = 0.3*tf.exp((V-40.0)/13.0)
    second_term = 0.002*tf.exp((60.0-V)/29)
    return first_term + second_term






# NEURONAL CURRENTS

def I_Na(V, m, h):
    """Na+ current: g_Na * m^3 * h * (V - E_Na)."""
    driving_force = V - E_Na
    return g_Na * m**3 * h * driving_force

def I_K(V, n):
    """Delayed-rectifier K+ current: g_K * n^4 * (V - E_K)."""
    driving_force = V - E_K
    return g_K * n**4 * driving_force

def I_L(V):
    """Leak current: g_L * (V - E_L)."""
    driving_force = V - E_L
    return g_L * driving_force

def I_KL(V):
    """K+ leak current: g_KL * (V - E_KL)."""
    driving_force = V - E_KL
    return g_KL * driving_force

def I_A(V, m, h):
    """Transient K+ (A) current: g_A * m^4 * h * (V - E_A)."""
    driving_force = V - E_A
    return g_A * m**4 * h * driving_force

# SYNAPTIC CURRENTS

def I_ach(o,V):
    """ACh synaptic current into each neuron.

    Args:
        o: open fractions of the ACh synapses, one per entry of `ach_mat` equal
            to 1, in row-major order (the order dAdt uses when masking T_ach).
        V: membrane potentials, one per neuron.

    Returns:
        For each neuron i (row of `ach_mat`): sum_j g_ach[j] * O[i, j] * (V[i] - E_ach[i]),
        where O is `o` scattered into an `ach_mat`-shaped matrix.
    """
    # Stateless scatter: the previous tf.Variable + scatter_update created a new
    # mutable variable on every ODE evaluation.
    indices = tf.constant(np.argwhere(ach_mat == 1), dtype=tf.int32)
    shape = tf.constant(ach_mat.shape, dtype=tf.int32)
    o_ = tf.scatter_nd(indices, o, shape)
    # Driving force uses the receiving (row) neuron's potential, so make it a column vector.
    driving_force = tf.expand_dims(V - E_ach, 1)
    return tf.reduce_sum(g_ach*o_*driving_force,1)

def I_fgaba(o,V):
    """Fast-GABA synaptic current into each neuron.

    Args:
        o: open fractions of the fGABA synapses, one per entry of `fgaba_mat`
            equal to 1, in row-major order (the order dAdt uses when masking T_fgaba).
        V: membrane potentials, one per neuron.

    Returns:
        For each neuron i (row of `fgaba_mat`): sum_j g_fgaba[j] * O[i, j] * (V[i] - E_fgaba[i]),
        where O is `o` scattered into a `fgaba_mat`-shaped matrix.
    """
    # Stateless scatter instead of creating a new tf.Variable on every call.
    indices = tf.constant(np.argwhere(fgaba_mat == 1), dtype=tf.int32)
    shape = tf.constant(fgaba_mat.shape, dtype=tf.int32)
    o_ = tf.scatter_nd(indices, o, shape)
    # Driving force uses the receiving (row) neuron's potential, so make it a column vector.
    driving_force = tf.expand_dims(V - E_fgaba, 1)
    return tf.reduce_sum(g_fgaba*o_*driving_force,1)


def I_inj_t(t, t_on=100, t_off=400):
    """External injected current at time `t`.

    Returns `inp` (as float64) while t_on < t < t_off and zeros of the same
    shape otherwise. The zero branch is built from `inp`, so both `tf.where`
    branches always have matching shapes.
    """
    current = tf.constant(inp,dtype=tf.float64)
    in_window = tf.logical_and(tf.greater(t,t_on),tf.less(t,t_off))
    return tf.where(in_window, current, tf.zeros_like(current))

In [198]:
def dAdt(X, t):
    """Right-hand side dX/dt of the network model.

    Layout of the state vector X (n_n neurons):
        [0:n_n]        V        membrane potentials
        [n_n:2n_n]     m, [2n_n:3n_n] h   Na+ gating
        [3n_n:4n_n]    n                  K+ gating
        [4n_n:5n_n]    m_a, [5n_n:6n_n] h_a  A-current gating
        next n_syn_ach entries            ACh open fractions
        next n_syn_fgaba entries          fGABA open fractions
        last n_n entries                  last firing times (zero derivative here)
    """
    V = X[0:n_n]
    m = X[n_n:2*n_n]
    h = X[2*n_n:3*n_n]
    n = X[3*n_n:4*n_n]
    # (Removed a stray `m_a = X[6*n_n:5*n_n]`, an empty slice that clobbered m_a,
    # and a duplicate h_a assignment.)
    m_a = X[4*n_n:5*n_n]
    h_a = X[5*n_n:6*n_n]
    o_ach = X[6*n_n:6*n_n+n_syn_ach]
    o_fgaba = X[6*n_n+n_syn_ach:6*n_n+n_syn_ach+n_syn_fgaba]
    fire_t = X[-n_n:]

    dVdt = (I_inj_t(t) - I_Na(V, m, h) - I_K(V, n) - I_A(V, m_a, h_a) - I_L(V) - I_KL(V) - I_ach(o_ach,V) - I_fgaba(o_fgaba,V)) / C_m

    m0,tm,h0,th = Na_prop(V)
    n0,tn = K_prop(V)

    dmdt = - (1.0/tm)*(m-m0)
    dhdt = - (1.0/th)*(h-h0)
    dndt = - (1.0/tn)*(n-n0)

    dm_adt = - (1.0/tau_m_a(V))*(m_a-m_a_inf(V))
    dh_adt = - (1.0/tau_h_a(V))*(h_a-h_a_inf(V))

    # ACh transmitter: amplitude A for presynaptic neuron j while
    # fire_t[j] + t_delay < t < fire_t[j] + t_max + t_delay, else 0.
    A_ = tf.constant(A,dtype=tf.float64)
    T_ach = tf.where(tf.logical_and(tf.greater(t,fire_t+t_delay),tf.less(t,fire_t+t_max+t_delay)),A_,tf.zeros(tf.shape(A_),dtype=A_.dtype))
    T_ach = tf.multiply(tf.constant(ach_mat,dtype=tf.float64),T_ach)
    T_ach = tf.boolean_mask(tf.reshape(T_ach,(-1,)),ach_mat.reshape(-1) == 1)
    do_achdt = alp_ach*(1.0-o_ach)*T_ach - bet_ach*o_ach

    # fGABA transmitter: sigmoid of the presynaptic potential.
    T_fgaba = 1.0/(1.0+tf.exp(-(V-V0)/sigma))
    T_fgaba = tf.multiply(tf.constant(fgaba_mat,dtype=tf.float64),T_fgaba)
    T_fgaba = tf.boolean_mask(tf.reshape(T_fgaba,(-1,)),fgaba_mat.reshape(-1) == 1)
    do_fgabadt = alp_fgaba*(1.0-o_fgaba)*T_fgaba - bet_fgaba*o_fgaba

    # Firing times do not evolve continuously.
    dfdt = tf.zeros(tf.shape(fire_t),dtype=fire_t.dtype)

    out = tf.concat([dVdt,dmdt,dhdt,dndt,dm_adt,dh_adt,do_achdt,do_fgabadt,dfdt],0)
    return out

In [199]:
# `n_` is read as a module-level global by the integrator's scan step: the
# number of trailing firing-time entries in the state vector.
n_ = n_n
state_vector = ([-65]*n_n            # V
                + [0.05]*n_n         # m
                + [0.6]*n_n          # h
                + [0.32]*n_n         # n
                + [0.05]*n_n         # m_a
                + [0.6]*n_n          # h_a
                + [0]*n_syn_ach      # ACh open fractions
                + [0]*n_syn_fgaba    # fGABA open fractions
                + [-500]*n_n)        # last firing times, far before t = 0
assert len(state_vector) == 7*n_n + n_syn_ach + n_syn_fgaba, "state layout must match dAdt"
init_state = tf.constant(state_vector, dtype=tf.float64)
tensor_state = odeint_fixed(dAdt, init_state, t)

In [200]:
%%time
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    state = sess.run(tensor_state)

CPU times: user 7min 1s, sys: 2min 14s, total: 9min 16s
Wall time: 2min 6s


In [201]:
# One panel per neuron: membrane potential (state columns 0..n_n-1) against time.
# squeeze=False keeps a 2-D array of axes even when n_n == 1.
fig, axes = plt.subplots(n_n, 1, figsize=(5, 5), sharex=True, sharey=True, squeeze=False)
for neuron, panel in enumerate(axes[:, 0]):
    panel.plot(t, state[:, neuron])
    panel.set_ylabel('V (neuron {})'.format(neuron))
axes[-1, 0].set_xlabel('t')
plt.tight_layout()
plt.show()

<IPython.core.display.Javascript object>