In [1]:
import numpy as np

In [2]:
# Hyper-parameters shared by the kernel and network cells below.
ETA, TAU = 1, 1  # refractory kernel `eta`: amplitude and decay time constant
TAU_S = 1        # synaptic kernel `eps`: time constant
SPIKE_THR = 0.5  # a neuron spikes when its potential is >= this value

In [3]:
def eta(s, amplitude=None, tau=None):
    """Refractory kernel applied after a neuron's own last spike.

    Args:
        s: time elapsed since the neuron's last spike (scalar or ndarray).
            ``+inf`` (the neuron never spiked) yields 0.
        amplitude: kernel amplitude; defaults to the module constant ``ETA``.
        tau: decay time constant; defaults to the module constant ``TAU``.

    Returns:
        ``-amplitude * exp(-s / tau)`` where ``s > 0``, and 0 elsewhere.
    """
    if amplitude is None:
        amplitude = ETA
    if tau is None:
        tau = TAU
    # Clamp s at 0 before exponentiating. The result is masked to 0 for s <= 0
    # anyway, but exp(-s / tau) for a large negative s would overflow to inf,
    # and inf * 0 is nan rather than 0.
    return -amplitude * np.exp(-np.maximum(s, 0) / tau) * (s > 0)

def eps(s, d, tau_s=None):
    """Delayed synaptic kernel ``x * exp(-x)`` with ``x = (s - d) / tau_s``.

    Args:
        s: time since the presynaptic neuron's last spike (scalar or ndarray).
            ``+inf`` (the neuron never spiked) yields 0.
        d: synaptic delay(s); broadcast against ``s``.
        tau_s: kernel time constant; defaults to the module constant ``TAU_S``.
            Assumed positive, so ``x > 0`` is the same as ``s > d``.

    Returns:
        ``x * exp(-x)`` where the delayed spike has arrived (``x > 0``), else 0.
    """
    if tau_s is None:
        tau_s = TAU_S
    res = (s - d) / tau_s
    # Gate on the *delayed* time. Before the delay has elapsed the spike has
    # not arrived, so the kernel must be 0, not the negative values x*exp(-x)
    # takes for x < 0. Non-finite x is excluded too: for s = +inf,
    # x * exp(-x) evaluates to inf * 0 = nan, whereas its limit is 0.
    active = (res > 0) & np.isfinite(res)
    safe = np.where(active, res, 0.0)
    return safe * np.exp(-safe) * active

In [4]:
class Layer:
    """One fully-connected layer of spike-response neurons.

    Neuron ``i`` of this layer receives presynaptic neuron ``j`` through weight
    ``W[j, i]`` and delay ``D[j, i]``. The potential computed at time step
    ``t`` is stored in row ``u[t]``.
    """

    def __init__(self, in_dim, out_dim, max_time, rng=None):
        """Create a layer with random weights and delays.

        Args:
            in_dim: number of presynaptic (input) neurons.
            out_dim: number of neurons in this layer.
            max_time: number of time-step rows allocated in ``u``.
            rng: object exposing ``uniform(low=, high=, size=)``, e.g.
                ``np.random.default_rng(seed)`` for reproducible weights.
                Defaults to the global ``np.random`` state.
        """
        if rng is None:
            rng = np.random  # same draws as before when no rng is supplied
        self.in_dim = in_dim
        self.out_dim = out_dim
        # W[j, i] / D[j, i]: weight / delay of the synapse from input j to neuron i.
        self.W = rng.uniform(low=0, high=1, size=(in_dim, out_dim))
        self.D = rng.uniform(low=1, high=5, size=(in_dim, out_dim))
        # u[t, i]: potential of neuron i at time step t.
        self.u = np.zeros((max_time, out_dim))

    def calc(self, t, last_time_spike_prev, last_time_spike_cur):
        """Compute and store the potentials of all neurons at time step ``t``.

        ``u[t, i] = eta(t - last_time_spike_cur[i])
        + sum_j W[j, i] * eps(t - last_time_spike_prev[j], D[j, i])``

        Args:
            t: integer time step (the row of ``self.u`` that is written).
            last_time_spike_prev: length ``in_dim``; last spike time of each
                presynaptic neuron.
            last_time_spike_cur: length ``out_dim``; last spike time of each
                neuron in this layer.
        """
        last_prev = np.asarray(last_time_spike_prev, dtype=float)
        last_cur = np.asarray(last_time_spike_cur, dtype=float)
        refractory = eta(t - last_cur)  # shape (out_dim,)
        # A column of elapsed times broadcasts against D -> (in_dim, out_dim).
        psp = eps((t - last_prev)[:, np.newaxis], self.D)
        self.u[t] = refractory + (self.W * psp).sum(axis=0)

In [5]:
class SNN:
    """Spiking network: input -> hidden layer 0 -> hidden layer 1 -> output.

    Keeps the last spike time of every neuron; ``-inf`` means the neuron has
    not spiked yet.
    """

    def __init__(self, in_dim, hidden_dims, out_dim, max_time):
        """Build the three layers and reset all spike times to ``-inf``.

        Args:
            in_dim: number of input neurons.
            hidden_dims: sizes of the hidden layers; only ``hidden_dims[0]``
                and ``hidden_dims[1]`` are used.
            out_dim: number of output neurons.
            max_time: number of time steps, passed to each ``Layer``.
        """
        self.layers = [
            Layer(in_dim, hidden_dims[0], max_time),
            Layer(hidden_dims[0], hidden_dims[1], max_time),
            Layer(hidden_dims[1], out_dim, max_time),
        ]
        # np.full replaces np.float(...) lists: the np.float alias was removed
        # in NumPy 1.24. The result is the same float64 arrays.
        self.last_time_spike_inp = np.full(in_dim, -np.inf)
        self.last_time_spike_h0 = np.full(hidden_dims[0], -np.inf)
        self.last_time_spike_h1 = np.full(hidden_dims[1], -np.inf)
        self.last_time_spike_out = np.full(out_dim, -np.inf)

    def calc(self, t, pattern):
        """Advance the whole network by one time step.

        Args:
            t: integer time step.
            pattern: length-``in_dim`` array-like; entries equal to 1 mark the
                input neurons that spike at ``t``.
        """
        # asarray so a plain list compares element-wise (list == 1 is just False).
        self.last_time_spike_inp[np.asarray(pattern) == 1] = t

        # Each layer sees the spikes its predecessor emitted at this same step t.
        stages = (
            (self.layers[0], self.last_time_spike_inp, self.last_time_spike_h0),
            (self.layers[1], self.last_time_spike_h0, self.last_time_spike_h1),
            (self.layers[2], self.last_time_spike_h1, self.last_time_spike_out),
        )
        for layer, prev_spikes, cur_spikes in stages:
            layer.calc(t, prev_spikes, cur_spikes)
            # Neurons at/above threshold spike now: record t as their last spike time.
            cur_spikes[layer.u[t, :] >= SPIKE_THR] = t
