In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
device = torch.device("cpu")
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
def PoissonGen(inp, rescale_fac=2.0):
    rand_inp = torch.rand_like(inp)
    return torch.mul(torch.le(rand_inp * rescale_fac, torch.abs(inp)).float(), torch.sign(inp))

def de_func(U):
    k=1
    U = 1 - k*abs(U)
    U[U<0]=0
    return U

In [57]:
class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        
        self.fc_1 = nn.Linear(28*28,4)
        self.fc_2 = nn.Linear(4,6)
        self.fc_out = nn.Linear(6,3)
        self.lif1 = LIF()
        self.lif2 = LIF()
        
    def forward(self, x, t):
        x = self.fc_1(x)
        x = self.lif1(x, t)
        x = self.fc_2(x)
        x = self.lif2(x, t)
        x = self.fc_out(x)
        
        return x

In [58]:
class LIF(nn.Module):
    def __init__(self):
        super(LIF, self).__init__()
        
        self.u_regs = None
        self.du_regs = None
        self.s_regs = None
        self.leak = 0.8
        self.time_step = 5
        self.thresh = 0.5
        
    def forward(self,inp, t):
        if t == 0:
            size = inp.shape
            self.u_regs = torch.zeros(self.time_step,*size)
            self.du_regs = torch.zeros(self.time_step,*size)
            self.u_regs[0] = inp
            self.s_regs = torch.zeros(self.time_step,*size)
            vol = inp - self.thresh
            spike = spike_function(vol, k=1)
            self.s_regs[0] = spike
        else:
            self.u_regs[t] = self.leak * self.u_regs[t-1] * (1 - self.s_regs[t-1]) + inp
            vol = self.u_regs[t] - self.thresh
            spike = spike_function(vol, k=1)
            self.s_regs[t] = spike

        return spike

        

def spike_function(x, k):
    x_fwd = torch.sign(x) * 0.5 + 0.5
    x_bwd = torch.clamp(1 - k*x.abs(), min=0.)
    return (x_fwd - x_bwd).detach() + x_bwd

In [63]:
### Back propagation
def bp(toy,leak,time_step):
    ## Use same Output derivative throught bp
    du_out = torch.exp(u_out)/torch.sum(torch.exp(u_out),1)

    ## Second fc
    du_fc2 = torch.matmul(du_out,toy.fc_out.weight)*de_func(toy.lif2.u_regs[-1])
    toy.lif2.du_regs[-1].data += du_fc2
    ## First fc
    du_fc1 = torch.matmul(du_fc2,toy.fc_2.weight)*de_func(toy.lif1.u_regs[-1])
    #print(du_fc1)
    toy.lif1.du_regs[-1].data += du_fc1
    #print(toy.lif1.du_regs[-1])

    ## Update weight
    w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[-1])
    toy.fc_1.weight.data += w_inp_1

    w_1_2 = torch.matmul(torch.transpose(du_fc2,0,1),toy.lif1.s_regs[-1])
    toy.fc_2.weight.data += w_1_2

    w_2_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif2.s_regs[-1])
    toy.fc_out.weight.data += w_2_out

    for t in range(time_step-2,-1,-1):
        ds_fc2 = torch.matmul(du_out,toy.fc_out.weight)+toy.lif2.du_regs[t+1]*(-leak*toy.lif2.u_regs[t])
        du_fc2 = (ds_fc2)*de_func(toy.lif2.u_regs[t]) + toy.lif2.du_regs[t+1]*leak*(1-toy.lif2.s_regs[t])
        toy.lif2.du_regs[t].data += du_fc2
        ## First fc
        ds_fc1 = torch.matmul(du_fc2,toy.fc_2.weight)+toy.lif1.du_regs[t+1]*(-leak*toy.lif1.u_regs[t])
        du_fc1 = (ds_fc1)*de_func(toy.lif1.u_regs[t]) + toy.lif1.du_regs[t+1]*leak*(1-toy.lif1.s_regs[t])
        #print(du_fc1)
        toy.lif1.du_regs[t].data += du_fc1
        #print(toy.lif1.du_regs[-1])

        ## Update weight
        w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[t])
        toy.fc_1.weight.data += w_inp_1

        w_1_2 = torch.matmul(torch.transpose(du_fc2,0,1),toy.lif1.s_regs[t])
        toy.fc_2.weight.data += w_1_2

        w_2_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif2.s_regs[t])
        toy.fc_out.weight.data += w_2_out


In [64]:
inp = torch.rand(1,28,28,requires_grad = True)
inp = inp.view(1,-1)

toy = model()
u_out = 0
time_step = 5
leak = 0.8

size = inp.shape
#s_regs_inp = torch.zeros_like(inp).expand(time_step, *size)
s_regs_inp = torch.zeros(time_step,*size)


for t in range(time_step):
    spike_inp = PoissonGen(inp)
  
    s_regs_inp[t].data += spike_inp
    
    u_out = u_out + toy(spike_inp,t)
loss = nn.CrossEntropyLoss()
target = torch.tensor([2])
#l = -1*u_out[0][1]+torch.log(torch.exp(u_out[0][0])+torch.exp(u_out[0][1]))
err = loss(u_out,target)
# print("err:",err)
# print("u_out:",u_out)
#torch.exp(u_out[0][0])/(torch.exp(u_out[0][0])+torch.exp(u_out[0][1])+torch.exp(u_out[0][2]))-0
print(err)
bp(toy,leak,time_step)
print(toy.fc_out.weight)

tensor(1.6726, grad_fn=<NllLossBackward>)
Parameter containing:
tensor([[ 1.1844,  1.0711,  0.3113, -0.2409, -0.0202,  0.0259],
        [ 0.9232,  0.7646,  0.1416,  0.3941, -0.3911, -0.3959],
        [ 0.7645,  0.4797,  0.0421, -0.2835, -0.0146, -0.1894]],
       requires_grad=True)
