In [52]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import json
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim
import time
from torch.utils.data import DataLoader, TensorDataset

In [53]:
def new_s(V_m):
    a_r = 1
    a_d = 5
    beta = 0.125
    V_th = -15 #??
    sig = 1 / (1 + np.exp(-beta * (V_m - V_th)))

    return (a_r * sig) / (a_r * sig + a_d)

In [123]:
class BetaNeuronNet(nn.Module):
    """One explicit-Euler step of a conductance-based network with chemical
    synapses and gap junctions; conductances and reversal potentials are
    trainable ``nn.Parameter``s.

    For neuron i (the last tensor dim indexes neurons; leading dims of the
    state tensors are treated as batch dims):

        I_leak_i = G_leak_i * (V_i - E_leak_i)
        I_syn_i  = sum_j G_syn[i, j] * s_j * (V_i - E_syn_j)
        I_gap_i  = sum_j G_gap[i, j] * (V_i - V_j)
        dV_i/dt  = -(I_leak_i + I_syn_i + I_gap_i)     # no capacitance term
        ds_i/dt  = a_r * sig(V_i) * (1 - s_i) - a_d * s_i
        sig(V)   = 1 / (1 + exp(-beta * (V - V_th)))

    NOTE(review): because the parameters require grad, every forward call
    records autograd history, and feeding outputs back in chains the steps.
    For pure simulation wrap the loop in ``torch.no_grad()``.
    """

    def __init__(self, G_leak, E_leak, G_syn, E_syn, G_gap,
                 a_r=1.0, a_d=5.0, beta=0.125, V_th=-15.0):
        """
        Args:
            G_leak, E_leak: per-neuron leak conductance / reversal, shape (n,).
            G_syn: (n, n) synaptic conductances; G_syn[i, j] scales the current
                into neuron i gated by neuron j's state s_j.
            E_syn: (n,) synaptic reversal potential of presynaptic neuron j.
            G_gap: (n, n) gap-junction conductances.
            a_r, a_d, beta, V_th: gating-ODE constants (see class docstring).
                TODO: V_th = -15 is unconfirmed.

        Array inputs are copied, so the model never aliases the caller's data.
        """
        super().__init__()

        self.G_leak = self._as_parameter(G_leak)
        self.E_leak = self._as_parameter(E_leak)

        self.G_syn = self._as_parameter(G_syn)
        self.E_syn = self._as_parameter(E_syn)

        self.G_gap = self._as_parameter(G_gap)

        self.a_r = a_r
        self.a_d = a_d
        self.beta = beta
        self.V_th = V_th

    @staticmethod
    def _as_parameter(value):
        """Wrap ``value`` in an ``nn.Parameter`` that owns its storage.

        ``torch.from_numpy`` shares memory with the ndarray, so in-place
        optimizer updates would silently rewrite the caller's numpy arrays.
        ``as_tensor`` keeps the source dtype (float64 for default numpy).
        """
        return nn.Parameter(torch.as_tensor(value).detach().clone())

    @staticmethod
    def _weighted_sum(weights, x):
        """Return y with y[..., i] = sum_j weights[i, j] * x[..., j].

        Same as ``weights @ x`` for 1-D ``x``, but broadcasts over leading
        batch dims and keeps normal dtype promotion (matmul would reject
        e.g. float32 ``x`` against float64 ``weights``).
        """
        return (weights * x.unsqueeze(-2)).sum(dim=-1)

    def calc_co_syn(self, big_s):
        """Synaptic coefficient: co_i = sum_j G_syn[i, j] * s_j."""
        return self._weighted_sum(self.G_syn, big_s)

    def calc_int_syn(self, big_s):
        """Synaptic intercept: int_i = sum_j G_syn[i, j] * s_j * E_syn_j."""
        return self._weighted_sum(self.G_syn, big_s * self.E_syn)

    def calc_I_syn(self, Voltage, co, int):
        """Synaptic current V * co - int = sum_j G_syn[i, j] s_j (V_i - E_syn_j)."""
        # `int` shadows the builtin; name kept so keyword callers keep working.
        return Voltage * co - int

    def calc_co_gap(self):
        """Gap-junction coefficient: co_i = sum_j G_gap[i, j]."""
        return self.G_gap.sum(dim=1)

    def calc_int_gap(self, voltage):
        """Gap-junction intercept: int_i = sum_j G_gap[i, j] * V_j."""
        return self._weighted_sum(self.G_gap, voltage)

    def calc_I_gap(self, voltage, co, int):
        """Gap-junction current V * co - int = sum_j G_gap[i, j] (V_i - V_j)."""
        return voltage * co - int

    def calc_I_leak(self, Voltage):
        """Leak current G_leak * (V - E_leak)."""
        return self.G_leak * (Voltage - self.E_leak)

    def calc_delta_V(self, voltage, I_leak, I_gap, I_syn):
        """dV/dt = -(I_leak + I_syn + I_gap); capacitance is not modelled.

        ``voltage`` is unused; kept for interface compatibility.
        """
        return -I_leak - I_syn - I_gap

    def calc_delta_s(self, Voltage, gate):
        """ds/dt = a_r * sig(V) * (1 - s) - a_d * s."""
        sig = 1 / (1 + torch.exp(-self.beta * (Voltage - self.V_th)))
        return self.a_r * sig * (1 - gate) - self.a_d * gate

    def forward(self, big_V, big_s, time_step):
        """Advance the network state by one explicit-Euler step.

        Args:
            big_V: membrane voltages, shape (..., n).
            big_s: synaptic gate values, shape (..., n).
            time_step: Euler step size.

        Returns:
            (new_V, new_s, I_leak, I_syn); the currents are evaluated at the
            input state, not the updated one.
        """
        leak_current = self.calc_I_leak(big_V)

        syn_current = self.calc_I_syn(
            big_V, self.calc_co_syn(big_s), self.calc_int_syn(big_s))
        gap_current = self.calc_I_gap(
            big_V, self.calc_co_gap(), self.calc_int_gap(big_V))

        delta_V = self.calc_delta_V(big_V, leak_current, gap_current, syn_current)
        delta_s = self.calc_delta_s(big_V, big_s)

        V_next = big_V + delta_V * time_step
        s_next = big_s + delta_s * time_step

        return V_next, s_next, leak_current, syn_current

<class 'torch.Tensor'> torch.Size([2])
<class 'torch.Tensor'> torch.Size([2])
<class 'torch.Tensor'> torch.Size([2, 2])
<class 'torch.Tensor'> torch.Size([2])
<class 'torch.Tensor'> torch.Size([2, 2])


In [128]:
%%timeit
V = torch.from_numpy(np.array([40.0, -40.0]))
s = torch.from_numpy(np.array([new_s(40.0), new_s(-40.0)]))
G_leak = np.array([10.0 for V_m in V])
E_leak = np.array([-35.0 for V_m in V])
G_syn = np.array([[0.0, 50.0], [80.0, 0.0]])
E_syn = np.array([0.0 for V_m in V])
G_gap = np.array([[0.0, 100.0], [100.0, 0.0]])

net = BetaNeuronNet(G_leak, E_leak, G_syn, E_syn, G_gap)

for i in range(30 * 1000):
    V, s, leak, syn = net(V, s, 0.001)

7.31 s ± 605 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
