# In [None]:
### imports ###
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.parameter import Parameter
from torch import optim
import torch.nn.functional as F

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import numpy as np
import math

# Report the PyTorch build and choose the compute device (GPU when one is available).
print(f"PyTorch version: {torch.__version__}")
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {dev}")

from IPython.core.debugger import set_trace



### classes ###

class Recurrent(nn.Module):
    """Recurrent spiking network with per-synapse axonal delays.

    Each step, the membrane potential ``h`` of every non-refractory neuron is
    updated as ``decay * h + dt_tau * (synaptic input + u_ext + eta)``, where
    ``eta`` is exponentially filtered Gaussian noise mixed through ``n_m``.
    A neuron fires when ``h >= threshold``, after which ``h`` is reset to 0 and
    the neuron stays clamped for its refractory number of steps.

    Caveat: a neuron that fires again before its previous spike has reached a
    target restarts that axon's countdown. At most one action potential
    travels down each axon at a time, so refractory periods must be long
    enough relative to the delays.

    Args:
        neuron_groups: 2-D tensor with one row per neuron group and columns
            ``[count, threshold, membrane time constant, refractory steps]``.
        delta_t: integration time step.
        noise_tau: time constant of the noise process ``eta``.
        trail_num: number of independent trials simulated in parallel (name
            kept for backward compatibility; stored as ``self.trial_num``).
        w_ij: optional ``(N, N)`` weights, ``w_ij[i, j]`` = weight from
            presynaptic ``j`` onto postsynaptic ``i``. Zero entries are treated
            as absent synapses and are held at zero by :meth:`constraints`.
            Defaults to standard-normal random weights.
        axon_delays: optional ``(N, N)`` delays in time steps, indexed like
            ``w_ij``. Defaults to random integers in ``[1, 100)``.

    Raises:
        ValueError: if ``w_ij`` or ``axon_delays`` is not ``(N, N)``.
    """

    def __init__(self, neuron_groups, delta_t, noise_tau, trail_num=1, w_ij=None, axon_delays=None):
        super(Recurrent, self).__init__()

        self.num_neurons = neuron_groups[:, 0].type(torch.IntTensor)
        self.tot_neurons = torch.sum(self.num_neurons).item()
        self.trial_num = trail_num
        n = self.tot_neurons
        state_shape = (self.trial_num, n)

        self.d_t = torch.tensor(delta_t, device=dev)
        self.n_sqdt_tau = torch.tensor(math.sqrt(delta_t) / noise_tau, device=dev)
        self.n_decay = torch.tensor(math.exp(-delta_t / noise_tau), device=dev)

        # noise mixing matrix (kept upper-triangular by constraints())
        self.n_m = Parameter(self.d_t * torch.eye(n, device=dev))
        self.subs = torch.ones((self.trial_num, n, n), device=dev)
        self.relu = nn.ReLU()

        # expand the per-group parameters to per-neuron tensors
        self.threshold = torch.empty(state_shape, device=dev)
        self.dt_tau = torch.empty(state_shape, device=dev)
        self.decay = torch.empty(state_shape, device=dev)
        self.refractory = torch.empty(state_shape, device=dev)
        start = 0
        for cnt, sz in enumerate(self.num_neurons.tolist()):
            cols = slice(start, start + sz)
            self.threshold[:, cols] = neuron_groups[cnt, 1]
            self.dt_tau[:, cols] = delta_t / neuron_groups[cnt, 2]
            self.decay[:, cols] = math.exp(-delta_t / neuron_groups[cnt, 2])
            self.refractory[:, cols] = neuron_groups[cnt, 3]
            start += sz
        self.num_groups = len(self.num_neurons)

        if w_ij is not None:
            w_init = torch.as_tensor(w_ij, dtype=torch.float32).detach().to(dev)
            if tuple(w_init.shape) != (n, n):
                raise ValueError(f"w_ij must have shape {(n, n)}, got {tuple(w_init.shape)}")
            # mask of existing synapses: zero initial weights stay zero
            self.grad_clip = (w_init != 0).to(torch.float32)
            self.w = Parameter(w_init.clone())
        else:
            self.grad_clip = torch.ones((n, n), device=dev)
            self.w = Parameter(torch.randn((n, n), device=dev))

        if axon_delays is not None:
            delays = torch.as_tensor(axon_delays, dtype=torch.float32).detach().to(dev)
            if tuple(delays.shape) != (n, n):
                raise ValueError(f"axon_delays must have shape {(n, n)}, got {tuple(delays.shape)}")
            self.axon_delays = delays.clone()
        else:
            self.axon_delays = torch.randint(1, 100, (n, n), device=dev)

        # E/I indicator per presynaptic neuron j: excitatory when column j sums > 0
        self.dale = (self.w.detach().sum(dim=0) > 0).cpu().type(torch.ByteTensor)

    def reset(self, init_fluc):
        """Re-initialise the dynamic state before a simulation.

        Args:
            init_fluc: initial membrane potentials are drawn uniformly from
                ``[-init_fluc, init_fluc)``.
        """
        self.h = (2.0 * torch.rand((self.trial_num, self.tot_neurons), device=dev) - 1.0) * init_fluc
        self.eta = torch.zeros((self.trial_num, self.tot_neurons), device=dev)
        self.refract = torch.zeros((self.trial_num, self.tot_neurons), device=dev)
        # axons[t, j, i]: remaining steps until j's spike reaches i (0 = nothing in flight)
        self.axons = torch.zeros((self.trial_num, self.tot_neurons, self.tot_neurons), device=dev)
        self.fires = (self.h >= self.threshold)

    def constraints(self):
        """Project the parameters back onto the allowed set.

        Zeroes weights of absent synapses (``grad_clip``), keeps ``n_m``
        upper-triangular, and enforces Dale's law by clamping negative weights
        in excitatory columns and positive weights in inhibitory columns to 0.
        """
        # assign through .data, otherwise the Parameters stop being leaf nodes
        w = self.w.data * self.grad_clip
        self.n_m.data = torch.triu(self.n_m.data)
        exc = self.dale.bool().to(w.device).unsqueeze(0)  # (1, N): broadcast over rows i
        wrong_sign = (exc & (w < 0)) | (~exc & (w > 0))
        self.w.data = w.masked_fill(wrong_sign, 0.0)

    def forward(self, u_ext):
        """Advance the network by one time step, updating the state in place.

        Args:
            u_ext: external input added to every neuron's drive; anything
                broadcastable to ``(trial_num, tot_neurons)``.
        """
        noise = torch.randn((self.trial_num, self.tot_neurons), device=dev)
        self.eta = self.n_decay * self.eta + self.n_sqdt_tau * torch.matmul(noise, self.n_m.t())

        # reset the neurons that fired on the previous step
        self.h[self.fires] = 0.0

        # launch a spike down every axon of each neuron that fired:
        # axons[t, j, i] = axon_delays[i, j] (a new spike overwrites one in flight)
        delays_ji = self.axon_delays.t().to(self.axons.dtype).unsqueeze(0)
        self.axons = torch.where(self.fires.unsqueeze(-1), delays_ji, self.axons)

        # spikes arriving this step; input to i is sum over j of w[i, j] * arrived[j -> i]
        arrivals = (self.axons == 1).to(self.w.dtype)  # (trial, pre j, post i)
        syn_input = torch.einsum("tji,ij->ti", arrivals, self.w)
        self.axons = self.relu(self.axons - self.subs)

        # only neurons outside their refractory period integrate
        refr = (self.refract > 0)
        free = ~refr
        h_next = self.decay * self.h + self.dt_tau * (syn_input + u_ext + self.eta)
        self.h[free] = h_next[free]

        self.fires = (self.h >= self.threshold)  # soma fires
        self.refract[self.fires] = self.refractory[self.fires]
        # count down the timers that were running at the start of this step
        self.refract = self.refract - refr.to(self.refract.dtype)

        
class FeedForward(nn.Module):
    """Delayed feed-forward projection from a spiking input population onto output neurons.

    Args:
        neuron_groups: 2-D tensor with one row per input group. Column 0 is the
            group size; column 1 is broadcast per neuron into ``self.alpha``.
        neurons_out: number of output (postsynaptic) neurons.
        trail_num: number of parallel trials (name kept for backward
            compatibility; stored as ``self.trial_num``).
        w_ij: ``(neurons_out, neurons_in)`` weights, ``w_ij[i, j]`` from input
            ``j`` onto output ``i``. Zero entries are absent synapses and are
            held at zero by :meth:`constraints`.
        c_ij: ``(neurons_out, neurons_in)`` axonal delays in time steps,
            indexed like ``w_ij``.

    Raises:
        ValueError: if ``w_ij`` or ``c_ij`` is not ``(neurons_out, neurons_in)``.
    """

    def __init__(self, neuron_groups, neurons_out, trail_num, w_ij, c_ij):
        super(FeedForward, self).__init__()

        # int32, not int8: int8 overflows for groups larger than 127 neurons
        self.num_neurons = neuron_groups[:, 0].type(torch.int32)
        self.neurons_in = torch.sum(self.num_neurons).item()
        self.neurons_out = neurons_out
        self.trial_num = trail_num

        self.alpha = torch.empty((self.trial_num, self.neurons_in), device=dev)
        start = 0
        for cnt, sz in enumerate(self.num_neurons.tolist()):
            self.alpha[:, start : start + sz] = neuron_groups[cnt, 1]
            start += sz

        shape = (neurons_out, self.neurons_in)
        w_init = torch.as_tensor(w_ij, dtype=torch.float32).detach().to(dev)
        delays = torch.as_tensor(c_ij, dtype=torch.float32).detach().to(dev)
        if tuple(w_init.shape) != shape:
            raise ValueError(f"w_ij must have shape {shape}, got {tuple(w_init.shape)}")
        if tuple(delays.shape) != shape:
            raise ValueError(f"c_ij must have shape {shape}, got {tuple(delays.shape)}")
        self.w = Parameter(w_init.clone())
        self.axon_delays = delays.clone()
        # keep existing synapses (non-zero initial weights); absent ones stay zero
        self.grad_clip = (w_init != 0).to(torch.float32)

        self.subs = torch.ones((self.trial_num, self.neurons_in, neurons_out), device=dev)
        self.relu = nn.ReLU()

        # E/I indicator per input neuron j: excitatory when column j sums > 0
        self.dale = (self.w.detach().sum(dim=0) > 0).cpu().type(torch.ByteTensor)

        # create self.axons so forward() works even before an explicit reset()
        self.reset(0.0)

    def reset(self, init_fluc):
        """Clear all in-flight spikes.

        ``init_fluc`` is unused; it is accepted for the same call signature as
        ``Recurrent.reset``.
        """
        # axons[t, j, i]: remaining steps until input j's spike reaches output i
        self.axons = torch.zeros((self.trial_num, self.neurons_in, self.neurons_out), device=dev)

    def constraints(self):
        """Zero absent synapses and enforce Dale's law by clamping wrong-sign weights to 0."""
        # assign through .data, otherwise the Parameter stops being a leaf node
        w = self.w.data * self.grad_clip
        exc = self.dale.bool().to(w.device).unsqueeze(0)  # (1, neurons_in)
        wrong_sign = (exc & (w < 0)) | (~exc & (w > 0))
        self.w.data = w.masked_fill(wrong_sign, 0.0)

    def forward(self, fires):
        """Propagate one time step and return the drive to the output neurons.

        Args:
            fires: ``(trial_num, neurons_in)`` spike indicators for this step
                (any non-zero / True entry counts as a spike).

        Returns:
            ``(trial_num, neurons_out)`` tensor: for each output ``i``, the sum
            of ``w[i, j]`` over inputs ``j`` whose spike arrives this step.
        """
        fired = torch.as_tensor(fires).to(self.axons.device) != 0
        # launch spikes: axons[t, j, i] = axon_delays[i, j] (overwrites one in flight)
        delays_ji = self.axon_delays.t().to(self.axons.dtype).unsqueeze(0)
        self.axons = torch.where(fired.unsqueeze(-1), delays_ji, self.axons)

        arrivals = (self.axons == 1).to(self.w.dtype)  # (trial, in j, out i)
        drive = torch.einsum("tji,ij->ti", arrivals, self.w)
        self.axons = self.relu(self.axons - self.subs)

        return drive
        

### initialization ###
# run length (steps), recording stride, offset (not referenced in this cell), group size
T = 5000
del_T = 1
offset = 0
latent_size = 50

d_t = 0.2
# one row per neuron group: [count, threshold, membrane time constant, refractory steps]
neuron_data = torch.tensor([[latent_size, 0.01, 20.0, 100], [latent_size, 0.2, 10.0, 200]])
neurons = int(neuron_data[:, 0].sum().item())

# presynaptic columns of the first group get +10, all remaining columns -10;
# no self-connections
w_ij = torch.full((neurons, neurons), -10.0)
w_ij[:, :latent_size] = 10.0
w_ij.fill_diagonal_(0.0)

# axonal delays (steps): 3 from the first group, 1 from the rest, 0 on the diagonal
c_ij = torch.ones((neurons, neurons))
c_ij[:, :latent_size] = 3
c_ij.fill_diagonal_(0)

trials = 1

# random weights/delays; append `, w_ij, c_ij` to use the structured connectivity above
model = Recurrent(neuron_data, d_t, 20.0, trials)

### simulation ###
h_sav = []  # membrane potentials, one (trials, neurons) array per recorded step
f_sav = []  # spike indicators, same layout
model.reset(0.0)
model.constraints()

# NOTE(review): not used by the loop below, and sized for 6 neurons while the
# network has `neurons` -- kept only in case later cells reference it.
I = torch.tensor([10.0, 10.0, 10.0, 0.0, 0.0, 0.0], device=dev)

# Pure simulation (no training): without no_grad, autograd would record a
# graph through model.w / model.n_m on every one of the T steps.
with torch.no_grad():
    for t in range(T + 1):
        if t % del_T == 0:
            # clone: on CPU .cpu() is a no-op and .numpy() would alias model.h,
            # which forward() updates in place, corrupting earlier snapshots
            h_sav.append(model.h.detach().cpu().clone().numpy())
            f_sav.append(model.fires.detach().cpu().clone().numpy())
        if t < T:
            model(0.0)

h = np.asarray(h_sav)  # (recorded steps, trials, neurons)
f = np.asarray(f_sav)



### visualization ###
# membrane potential of neuron 0 in trial 0 across the recorded steps
v = range(h.shape[0])
plt.plot(v, h[:, 0, 0])

# spike raster of trial 0: rows are neurons, columns are recorded steps;
# the figure width scales with the number of steps per neuron
height = 5
width = h.shape[0] / h.shape[2]
a = f[:, 0, :]
plt.figure(figsize=(width, height))
plt.imshow(a.T, interpolation="nearest", cmap="gray")
plt.colorbar()
plt.xlabel("Timestep")
plt.ylabel("Neuron")
plt.show()
