In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms

In [2]:
# Compute device for the notebook, fixed to the CPU.
# NOTE(review): none of the cells shown here read `device` - no model, buffer or
# batch is moved with `.to(device)`. Switching to the commented GPU line below
# would therefore also require moving the model and each batch explicitly.
device = torch.device("cpu")
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
def PoissonGen(inp, rescale_fac=2.0):
    """Rate-code ``inp`` into signed spikes for a single time step.

    Every element fires with probability ``min(|inp| / rescale_fac, 1)``: a
    uniform draw in [0, 1), scaled by ``rescale_fac``, is compared against
    ``|inp|``. A firing element becomes ``sign(inp)`` (+1 or -1); all other
    elements (and zero inputs) become 0. The result has the shape of ``inp``.
    """
    scaled_draw = torch.rand_like(inp) * rescale_fac
    fired = (scaled_draw <= inp.abs()).float()
    return fired * inp.sign()

def spike_function(x, k):
    """Heaviside step used as the LIF spike non-linearity.

    Returns a NEW tensor with the dtype of ``x``: 1 where ``x > 0`` and 0
    everywhere else (NaN entries included). ``x`` itself is not modified; the
    previous implementation overwrote ``x`` in place, which silently clobbered
    the caller's tensor and left NaN entries unconverted.

    Args:
        x: tensor to threshold (LIF passes membrane potential minus threshold).
        k: unused; kept so existing ``spike_function(vol, k=1)`` calls still work.

    Returns:
        Tensor of 0/1 values, same shape and dtype as ``x``.
    """
    return (x > 0).to(x.dtype)

def de_func(U,th):
    """Triangular surrogate derivative of the spike function around the threshold.

    Computes ``0.3 * max(0, 1 - |U - th| / th)`` element-wise: the value peaks
    at 0.3 where ``U == th`` and falls linearly to 0 at ``U = 0`` and
    ``U = 2 * th``. A new tensor is returned; ``U`` is left unchanged.
    """
    peak = 0.3
    distance = torch.abs((U - th) / th)
    return torch.clamp(peak * (1.0 - distance), min=0)

def test():
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        output = toy(data)
        test_loss +=F.cross_entropy(output, target, size_average=False).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

def quant(input, k):
    """Snap values to a uniform grid of step ``2 / k`` anchored at -1 and 1.

    The tensor is mapped affinely from [-1, 1] to [0, 1], rounded to the nearest
    multiple of ``1 / k`` (``torch.round``, i.e. ties go to the even multiple),
    and mapped back. The resulting grid points are ``-1 + 2 * j / k``.
    Values outside [-1, 1] are NOT clipped; they snap to the same grid extended
    beyond that range.

    Args:
        input: tensor to quantize; it is not modified.
        k: number of grid steps across [-1, 1]. NOTE(review): an old comment
            hinted at ``2**num_bits - 1`` steps, but ``k`` is used directly.

    Returns:
        A new tensor with the same shape as ``input``.
    """
    steps = k
    unit = input.add(1).div(2)              # [-1, 1] -> [0, 1]
    unit = unit.mul(steps).round_().div(steps)
    return unit.add(-0.5).mul(2)            # [0, 1] -> [-1, 1]

In [4]:
class model(nn.Module):
    """Two-hidden-layer spiking network (784 -> 256 -> 256 -> 10), bias-free.

    At each of ``time_step`` steps the flattened input is rate-coded with
    ``PoissonGen``, passed through ``fc_1`` -> ``lif1`` -> ``fc_2`` -> ``lif2``
    -> ``fc_out``, and the per-step outputs are summed over time.
    The per-step input spikes are kept in ``s_regs_inp`` with shape
    ``(time_step, batch, features)`` so a manual back-propagation routine can
    read them.

    Args:
        time_step: number of simulation steps (default 5, the previous fixed
            value). It is copied onto both LIF layers at every forward pass so
            their state buffers always match the number of steps simulated.
    """

    def __init__(self, time_step=5):
        super(model, self).__init__()

        self.fc_1 = nn.Linear(28*28, 256, bias=False)
        self.fc_2 = nn.Linear(256, 256, bias=False)
        self.fc_out = nn.Linear(256, 10, bias=False)
        self.lif1 = LIF()
        self.lif2 = LIF()
        self.time_step = time_step
        self.s_regs_inp = None

    def forward(self, inp):
        """Simulate ``self.time_step`` steps; return the time-summed ``(batch, 10)`` outputs.

        ``inp`` is flattened to ``(batch, -1)``; ``fc_1`` expects 28*28 features.
        """
        inp = inp.view(inp.shape[0], -1)
        # Keep the LIF buffer length equal to the number of steps simulated here,
        # even if ``self.time_step`` was reassigned after construction.
        self.lif1.time_step = self.time_step
        self.lif2.time_step = self.time_step
        self.s_regs_inp = torch.zeros(self.time_step, *inp.shape)
        u_out = 0

        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp

            x = self.lif1(self.fc_1(spike_inp), t)
            x = self.lif2(self.fc_2(x), t)
            u_out = u_out + self.fc_out(x)
        return u_out

In [5]:
class MLP(nn.Module):
    """One-hidden-layer spiking network (784 -> 512 -> 10), bias-free.

    At each of ``time_step`` steps the flattened input is rate-coded with
    ``PoissonGen``, projected by ``fc_1``, quantized with ``quant(., 4)``,
    fed to the ``lif1`` neurons, and their spikes are projected by ``fc_out``.
    The per-step outputs are summed over time. Input spikes are stored in
    ``s_regs_inp`` with shape ``(time_step, batch, features)`` for ``bp_MLP``.

    Args:
        time_step: number of simulation steps (default 5, the previous fixed
            value). It is copied onto ``lif1`` at every forward pass so the LIF
            state buffers always match the number of steps simulated.
    """

    def __init__(self, time_step=5):
        super(MLP, self).__init__()

        self.fc_1 = nn.Linear(28*28, 512, bias=False)
        self.fc_out = nn.Linear(512, 10, bias=False)
        self.lif1 = LIF()
        self.time_step = time_step
        self.s_regs_inp = None

    def forward(self, inp):
        """Simulate ``self.time_step`` steps; return the time-summed ``(batch, 10)`` outputs.

        ``inp`` is flattened to ``(batch, -1)``; ``fc_1`` expects 28*28 features.
        """
        inp = inp.view(inp.shape[0], -1)
        # Keep the LIF buffer length equal to the number of steps simulated here,
        # even if ``self.time_step`` was reassigned after construction.
        self.lif1.time_step = self.time_step
        self.s_regs_inp = torch.zeros(self.time_step, *inp.shape)
        u_out = 0

        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp
            current = quant(self.fc_1(spike_inp), 4)
            spikes = self.lif1(current, t)
            u_out = u_out + self.fc_out(spikes)
        return u_out
        

In [6]:
class LIF(nn.Module):
    """Layer of leaky integrate-and-fire neurons with hard reset, advanced one step per call.

    State for a whole simulation window lives in ``(time_step, *input_shape)``
    buffers: ``u_regs`` (membrane potentials), ``s_regs`` (emitted spikes) and
    ``du_regs`` (zero-initialized here; the manual back-propagation routines
    accumulate into it).

    Dynamics for input current ``inp`` at step ``t``:
        u[0] = inp
        u[t] = leak * u[t-1] * (1 - s[t-1]) + inp    for t > 0 (a spike resets u)
        s[t] = spike_function(u[t] - thresh)

    Args:
        time_step: length of the state buffers (default 5).
        leak: multiplicative membrane decay per step (default 0.99).
        thresh: firing threshold (default 0.5).
    """

    def __init__(self, time_step=5, leak=0.99, thresh=0.5):
        super(LIF, self).__init__()

        self.u_regs = None
        self.du_regs = None
        self.s_regs = None
        self.leak = leak
        self.time_step = time_step
        self.thresh = thresh

    def forward(self, inp, t):
        """Advance to step ``t`` and return the spikes emitted at that step.

        A call with ``t == 0`` (re)allocates all state buffers, shaped like
        ``inp`` and on its device/dtype, so every new batch starts clean. Later
        steps read the stored ``u[t-1]`` and ``s[t-1]``, so they are expected to
        be called in order ``t = 1, 2, ...``.

        Raises:
            RuntimeError: if ``t > 0`` is requested before any ``t == 0`` call.
        """
        if t == 0:
            buf_shape = (self.time_step, *inp.shape)
            # Allocate alongside the input so the t > 0 update never mixes devices.
            self.u_regs = torch.zeros(buf_shape, dtype=inp.dtype, device=inp.device)
            self.du_regs = torch.zeros(buf_shape, dtype=inp.dtype, device=inp.device)
            self.s_regs = torch.zeros(buf_shape, dtype=inp.dtype, device=inp.device)
            membrane = inp
        else:
            if self.u_regs is None:
                raise RuntimeError("LIF.forward: call with t == 0 before later time steps")
            membrane = self.leak * self.u_regs[t-1] * (1 - self.s_regs[t-1]) + inp

        self.u_regs[t] = membrane
        spike = spike_function(membrane - self.thresh, k=1)
        self.s_regs[t] = spike
        return spike


In [7]:
### Back propagation for MLP
def bp_MLP(toy,leak,time_step,du_out,s_regs_inp,l_r,th):
    """Manual back-propagation through time plus SGD update for the ``MLP`` network.

    Walks the simulation window backwards (last step first). At each step it
    computes ``du_fc1``, the gradient w.r.t. the ``lif1`` membrane potential.
    It accumulates that into ``toy.lif1.du_regs[t]`` and immediately updates
    ``fc_1`` and ``fc_out`` in place (via ``.data``) with gradients passed
    through ``quant(., 4)``. With k=4 the ``quant`` grid step is 0.5, so
    gradient entries of magnitude <= 0.25 become 0.

    Args:
        toy: network with ``fc_1``, ``fc_out`` and ``lif1`` whose forward pass
            has just run (``lif1.u_regs`` / ``s_regs`` / ``du_regs`` filled).
        leak: membrane leak; expected to equal ``toy.lif1.leak``.
        time_step: number of simulated steps; NOTE(review): assumed to equal the
            LIF buffer length, since index -1 is treated as step time_step-1.
        du_out: dL/d(output) for the time-summed output (softmax - one_hot in
            the training cell). Because the output is a plain sum over steps,
            the same ``du_out`` applies at every step.
        s_regs_inp: input spikes per step, ``(time_step, batch, in_features)``.
        l_r: learning rate.
        th: threshold handed to the surrogate derivative ``de_func``.

    NOTE(review): the weights are modified before they are re-read for earlier
    steps (``toy.fc_out.weight`` inside the loop), so earlier-step gradients use
    already-updated weights. Exact BPTT would compute every gradient first and
    update afterwards; confirm this is intended. Gradients are summed (not
    averaged) over the batch.
    """

    ## Last time step: no future-state term, only the direct path through fc_out.
    du_fc1 = torch.matmul(du_out,toy.fc_out.weight)*de_func(toy.lif1.u_regs[-1],th)
    toy.lif1.du_regs[-1] += du_fc1

    ## Update weights: (out, batch) @ (batch, in) matches nn.Linear's (out, in) weight.
    w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[-1])
    toy.fc_1.weight.data -= l_r*quant(w_inp_1,4)
#     toy.fc_1.weight.data -= l_r*w_inp_1

    w_1_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif1.s_regs[-1])
    toy.fc_out.weight.data -= l_r*quant(w_1_out,4)
#     toy.fc_out.weight.data -= l_r*w_1_out

    # Remaining steps, from time_step-2 down to 0.
    for t in range(time_step-2,-1,-1):

        ## dL/ds1[t]: direct path through fc_out, plus the reset path of
        ## u[t+1] = leak*u[t]*(1-s[t]) + inp, whose derivative w.r.t. s[t] is -leak*u[t].
        ds_fc1 = torch.matmul(du_out,toy.fc_out.weight)+toy.lif1.du_regs[t+1]*(-leak*toy.lif1.u_regs[t])
        # dL/du1[t]: spike path via the surrogate derivative, plus the leak path leak*(1-s[t]).
        du_fc1 = (ds_fc1)*de_func(toy.lif1.u_regs[t],th) + toy.lif1.du_regs[t+1]*leak*(1-toy.lif1.s_regs[t])
        toy.lif1.du_regs[t] += du_fc1

        ## Update weights with this step's contribution (quantized).
        w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[t])
        toy.fc_1.weight.data -= l_r*quant(w_inp_1,4)
#         toy.fc_1.weight.data -= l_r*w_inp_1

        w_1_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif1.s_regs[t])
        toy.fc_out.weight.data -= l_r*quant(w_1_out,4)
#         toy.fc_out.weight.data -= l_r*w_1_out


In [8]:
### Back propagation
def bp(toy,leak,time_step,du_out,s_regs_inp,l_r,th):
    """Manual back-propagation through time plus SGD update for the two-hidden-layer ``model``.

    Walks the simulation window backwards (last step first). At each step it
    computes ``du_fc2`` / ``du_fc1``, the gradients w.r.t. the ``lif2`` / ``lif1``
    membrane potentials, and accumulates them into the layers' ``du_regs``.
    It then immediately updates ``fc_1``, ``fc_2`` and ``fc_out`` in place
    (via ``.data``). Unlike ``bp_MLP``, these weight gradients are not quantized.

    Args:
        toy: network with ``fc_1``, ``fc_2``, ``fc_out``, ``lif1``, ``lif2``
            whose forward pass has just run (LIF state buffers filled).
        leak: membrane leak; expected to equal the LIF layers' ``leak``.
        time_step: number of simulated steps; NOTE(review): assumed to equal the
            LIF buffer length, since index -1 is treated as step time_step-1.
        du_out: dL/d(output) for the time-summed output. Because the output is a
            plain sum over steps, the same value applies at every step.
        s_regs_inp: input spikes per step, ``(time_step, batch, in_features)``.
        l_r: learning rate.
        th: threshold handed to ``de_func``, used for both LIF layers.

    NOTE(review): weights are updated before being re-read for earlier steps
    (``fc_out.weight`` and ``fc_2.weight`` inside the loop), so earlier-step
    gradients use already-updated weights; confirm this is intended.
    Gradients are summed (not averaged) over the batch.
    """

    ## Last time step, second hidden layer: direct path through fc_out.
    du_fc2 = torch.matmul(du_out,toy.fc_out.weight)*de_func(toy.lif2.u_regs[-1],th)    
    toy.lif2.du_regs[-1] = toy.lif2.du_regs[-1] + du_fc2

    ## Last time step, first hidden layer: back through fc_2 into lif1.
    du_fc1 = torch.matmul(du_fc2,toy.fc_2.weight)*de_func(toy.lif1.u_regs[-1],th)
    toy.lif1.du_regs[-1] += du_fc1


    ## Update weights: (out, batch) @ (batch, in) matches nn.Linear's (out, in) weight.
    w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[-1])
    toy.fc_1.weight.data -= l_r*w_inp_1

    w_1_2 = torch.matmul(torch.transpose(du_fc2,0,1),toy.lif1.s_regs[-1])
    toy.fc_2.weight.data -= l_r*w_1_2

    w_2_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif2.s_regs[-1])
    toy.fc_out.weight.data -= l_r*w_2_out

    # Remaining steps, from time_step-2 down to 0.
    for t in range(time_step-2,-1,-1):

        ## Second hidden layer. dL/ds2[t] is the direct path through fc_out plus the
        ## reset path of u[t+1] = leak*u[t]*(1-s[t]) + inp, whose d/ds[t] is -leak*u[t].
        ds_fc2 = torch.matmul(du_out,toy.fc_out.weight)+toy.lif2.du_regs[t+1]*(-leak*toy.lif2.u_regs[t])
        # dL/du2[t]: spike path via the surrogate derivative, plus the leak path leak*(1-s[t]).
        du_fc2 = (ds_fc2)*de_func(toy.lif2.u_regs[t],th) + toy.lif2.du_regs[t+1]*leak*(1-toy.lif2.s_regs[t])
        toy.lif2.du_regs[t] += du_fc2

        ## First hidden layer: same structure, fed by du_fc2 through fc_2.
        ds_fc1 = torch.matmul(du_fc2,toy.fc_2.weight)+toy.lif1.du_regs[t+1]*(-leak*toy.lif1.u_regs[t])
        du_fc1 = (ds_fc1)*de_func(toy.lif1.u_regs[t],th) + toy.lif1.du_regs[t+1]*leak*(1-toy.lif1.s_regs[t])
        toy.lif1.du_regs[t] += du_fc1

        ## Update weights with this step's contribution.
        w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[t])
        toy.fc_1.weight.data -= l_r*w_inp_1

        w_1_2 = torch.matmul(torch.transpose(du_fc2,0,1),toy.lif1.s_regs[t])
        toy.fc_2.weight.data -= l_r*w_1_2

        w_2_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif2.s_regs[t])
        toy.fc_out.weight.data -= l_r*w_2_out


In [9]:
import torch
import torchvision

batch_size_train = 64
batch_size_test = 1000

# Cache MNIST under a project-relative directory. The previous root '/' is the
# filesystem root, which normally requires administrator privileges to write.
data_root = './data'

# Identical preprocessing for both splits: to tensor, then normalize with
# (0.1307, 0.3081) - presumably the MNIST training-set pixel mean/std.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(data_root, train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

# Evaluation metrics do not depend on order, so the test split is not shuffled.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(data_root, train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_test, shuffle=False)

In [10]:
# Histories consumed by the (commented-out) loss plot in the last cell.
train_losses = []
train_counter = []
test_losses = []
# Training examples seen at each test() call: once before training, then after
# each of the 5 epochs.
test_counter = [n_done * len(train_loader.dataset) for n_done in range(6)]
log_interval = 10  # print/record the training loss every this many batches

In [11]:
toy = MLP()
# Read the neuron constants back from the model so the manual BP cannot drift
# out of sync with the forward pass (same values as before: 0.99 and 5).
leak = toy.lif1.leak
time_step = toy.time_step
lr = 0.0005
loss = nn.CrossEntropyLoss()
n_epochs = 5

test()
for epoch in range(n_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Gradients are computed by hand in bp_MLP, so no autograd graph is needed.
        with torch.no_grad():
            out = toy(data)
            err = loss(out, target)

            # dL/d(out) for softmax cross-entropy is softmax(out) - one_hot(target);
            # torch.softmax is the numerically stable form of exp(out) / sum(exp(out)).
            target_onehot = F.one_hot(target, num_classes=10)
            du_out = torch.softmax(out, dim=1) - target_onehot

            bp_MLP(toy, leak, time_step, du_out, toy.s_regs_inp, lr, toy.lif1.thresh)

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), err.item()))
            train_losses.append(err.item())
            # epoch is 0-based here, so examples seen = full epochs done + offset in this one.
            train_counter.append((batch_idx * batch_size_train) + (epoch * len(train_loader.dataset)))

    test()





Test set: Avg. loss: 2.3788, Accuracy: 810/10000 (8%)


Test set: Avg. loss: 0.3739, Accuracy: 8923/10000 (89%)




Test set: Avg. loss: 0.3468, Accuracy: 8961/10000 (90%)


Test set: Avg. loss: 0.3084, Accuracy: 9090/10000 (91%)




Test set: Avg. loss: 0.2990, Accuracy: 9096/10000 (91%)


Test set: Avg. loss: 0.2842, Accuracy: 9124/10000 (91%)



In [12]:
# import matplotlib.pyplot as plt
# fig = plt.figure()
# plt.plot(train_counter, train_losses, color='blue')
# plt.scatter(test_counter, test_losses, color='red')
# plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
# plt.xlabel('number of training examples seen')
# plt.ylabel('negative log likelihood loss')