In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms

In [2]:
# Compute device for the whole notebook: the first CUDA GPU when one is
# visible, otherwise the CPU. Assign torch.device("cpu") here to force CPU runs.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def PoissonGen(inp, rescale_fac=2.0):
    """Turn ``inp`` into a signed stochastic spike tensor of the same shape.

    A uniform sample in [0, 1) is drawn per element and scaled by
    ``rescale_fac``; an element fires when that scaled sample is <= ``|inp|``.
    Firing elements carry the sign of ``inp`` (+1 / -1); all others are 0.

    Args:
        inp: input tensor (any shape).
        rescale_fac: scale applied to the uniform sample before comparison;
            larger values make spikes less likely.

    Returns:
        Float tensor with entries in {-1, 0, 1}.
    """
    scaled_noise = torch.rand_like(inp) * rescale_fac
    fired = (scaled_noise <= inp.abs()).float()
    return fired * inp.sign()

def spike_function(x, k):
    """Binarise ``x`` in place and return it.

    Entries > 0 become 1 and entries <= 0 become 0. Entries for which neither
    comparison holds (NaN) are left as they are. ``k`` is accepted but not
    used by this implementation.

    Args:
        x: tensor to threshold; it is modified in place.
        k: unused.

    Returns:
        The same tensor object ``x`` after thresholding.
    """
    fired = x > 0
    silent = x <= 0
    x.masked_fill_(fired, 1)
    x.masked_fill_(silent, 0)
    return x

def de_func(U,th):
    """Triangular window around the threshold ``th``.

    Evaluates ``0.4 * (1 - |U - th| / th)`` element-wise and replaces every
    negative result with 0, so the output peaks at 0.4 where ``U == th`` and
    is 0 once ``|U - th| >= th`` (for positive ``th``).

    Args:
        U: tensor of values to evaluate (not modified).
        th: threshold the window is centred on; also sets its half-width.

    Returns:
        New tensor with the same shape as ``U``.
    """
    alpha = 0.4
    rel_dist = abs((U - th) / th)
    window = alpha * (1.0 - rel_dist)
    return torch.clamp(window, min=0)

def test(toy):
    """Evaluate ``toy`` on the global ``test_loader`` and log the result.

    The summed cross-entropy and the number of correct top-1 predictions are
    accumulated over every batch. The per-sample average loss is appended to
    the global ``test_losses`` list and a summary line is printed.

    Args:
        toy: network to evaluate; it is moved to the global ``device``.

    Returns:
        None.
    """
    test_loss = 0
    correct = 0
    # Follow the notebook-wide device choice instead of assuming CUDA exists.
    toy = toy.to(device)
    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = toy(data)
            # reduction='sum' replaces the deprecated size_average=False.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            # .item() keeps the running count a plain Python int.
            correct += pred.eq(target.view_as(pred)).sum().item()
    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_samples,
        100. * correct / n_samples))

def quant(input, k):
    """Round ``input`` onto a uniform grid with spacing ``2 / k``.

    Each value is shifted and scaled as ``(x + 1) / 2``, multiplied by ``k``,
    rounded to the nearest integer (``torch.round``), and mapped back with the
    inverse transform. Values are not clipped, so inputs outside [-1, 1] are
    rounded onto the same grid extended beyond that interval.

    Args:
        input: tensor to quantise (not modified). Empty tensors are accepted.
        k: number of grid steps per unit of the shifted range, e.g. ``2**4``.

    Returns:
        New tensor with the same shape as ``input``.
    """
    # (x + 1) / 2, then count grid steps.
    steps = input.add(1).div(2).mul(k).round()
    # Undo the scaling: steps / k - 0.5, then * 2.
    return steps.div(k).add(-0.5).mul(2)

In [4]:
class model(nn.Module):
    """Spiking MLP with two hidden LIF layers (784 -> 256 -> 256 -> 10).

    The flattened input is re-encoded with ``PoissonGen`` at every time step,
    pushed through the bias-free linear layers and ``LIF`` units, and the
    read-out is averaged over ``time_step`` steps. The input spikes of the
    latest forward pass are kept in ``s_regs_inp`` (indexed by time step).
    """

    def __init__(self, time_step,leak):
        super().__init__()

        # Layers are created in a fixed order: fc_1, fc_2, fc_out.
        self.fc_1 = nn.Linear(28 * 28, 256, bias=False)
        self.fc_2 = nn.Linear(256, 256, bias=False)
        self.fc_out = nn.Linear(256, 10, bias=False)

        self.lif1 = LIF(time_step, leak)
        self.lif2 = LIF(time_step, leak)
        self.time_step = time_step
        self.s_regs_inp = None

    def forward(self, inp):
        """Run ``time_step`` simulation steps and return the mean read-out."""
        flat = inp.view(inp.shape[0], -1)
        self.s_regs_inp = torch.zeros(self.time_step, *flat.shape, device=device)
        u_out = 0

        for t in range(self.time_step):
            spikes = PoissonGen(flat)
            self.s_regs_inp[t] += spikes

            hidden = self.lif1(self.fc_1(spikes), t)
            hidden = self.lif2(self.fc_2(hidden), t)
            u_out = u_out + self.fc_out(hidden)

        return u_out / self.time_step

In [5]:
class MLP(nn.Module):
    """Spiking MLP with a single hidden LIF layer (784 -> 512 -> 10).

    Every time step re-encodes the flattened input with ``PoissonGen``; the
    read-out of all ``time_step`` steps is summed (not averaged). The input
    spikes of the latest forward pass are kept in ``s_regs_inp``.
    """

    def __init__(self,time_step,leak):
        super().__init__()

        self.fc_1 = nn.Linear(28 * 28, 512, bias=False)
        self.fc_out = nn.Linear(512, 10, bias=False)
        self.lif1 = LIF(time_step, leak)
        self.time_step = time_step
        self.s_regs_inp = None

    def forward(self, inp):
        """Run ``time_step`` simulation steps and return the summed read-out."""
        flat = inp.view(inp.shape[0], -1)

        # One slot per time step for the encoded input spikes.
        self.s_regs_inp = torch.zeros(self.time_step, *flat.shape, device=device)
        u_out = 0

        for t in range(self.time_step):
            spikes = PoissonGen(flat)
            self.s_regs_inp[t] += spikes
            hidden = self.lif1(self.fc_1(spikes), t)
            u_out = u_out + self.fc_out(hidden)
        return u_out
        

In [6]:
class VGG_5(nn.Module):
    """Small spiking conv net: three 3x3 conv layers, two 2x2 max-pools, two FC layers.

    ``fc1`` is sized for ``128 * 7 * 7`` features, i.e. single-channel 28x28
    inputs after the two poolings. The read-out is summed over ``time_step``
    steps; input spikes of the latest forward pass are kept in ``s_regs_inp``.

    Args:
        time_step: number of simulation steps per forward pass.
        leak: leak factor handed to every ``LIF`` unit. ``LIF`` requires it, so
            the previous one-argument construction raised ``TypeError``; the
            default of 0.98 mirrors the value used by this notebook's training
            cell so ``VGG_5(time_step)`` keeps working.
    """

    def __init__(self,time_step,leak=0.98):
        super(VGG_5, self).__init__()

        self.time_step = time_step
        self.s_regs_inp = None
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.lif1 = LIF(time_step, leak)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
        self.lif2 = LIF(time_step, leak)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)
        self.lif3 = LIF(time_step, leak)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.fc1 = nn.Linear(128 * 7 * 7, 1024, bias=False)
        self.lif4 = LIF(time_step, leak)
        self.fc_out = nn.Linear(1024, 10, bias=False)

    def forward(self, inp):
        """Run ``time_step`` simulation steps and return the summed read-out."""
        size = inp.shape
        self.s_regs_inp = torch.zeros(self.time_step,*size, device=device)
        u_out = 0
        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp
            x = self.conv1(spike_inp)
            x = self.lif1(x,t)
            x = self.pool1(x)
            x = self.conv2(x)
            x = self.lif2(x,t)
            x = self.conv3(x)
            x = self.lif3(x,t)
            x = self.pool2(x)
            # Flatten (batch, C, H, W) -> (batch, C*H*W) for the FC layers.
            x = x.view(x.shape[0],-1)
            x = self.fc1(x)
            x = self.lif4(x,t)
            x = self.fc_out(x)
            u_out = u_out + x
        return u_out

In [None]:
class VGG_1(nn.Module):
    """Minimal spiking conv net: one 3x3 conv, one 2x2 max-pool, two FC layers.

    ``fc1`` is sized for ``64 * 14 * 14`` features: the 64 channels produced
    by ``conv1`` on single-channel 28x28 inputs after one 2x2 pooling.

    State recorded by the latest forward pass (one entry per time step):
        s_regs_inp:   encoded input spikes.
        s_regs_conv:  flattened, pooled conv spikes fed to ``fc1``.
        pool_indices: list of arg-max indices returned by ``pool1``; these are
            what ``nn.MaxUnpool2d`` needs to route values back through the pool.

    Args:
        time_step: number of simulation steps per forward pass.
        leak: leak factor handed to every ``LIF`` unit (``LIF`` requires it);
            the default of 0.98 mirrors the value used by this notebook's
            training cell.
    """

    def __init__(self,time_step,leak=0.98):
        # Must name this class; super(VGG_5, self) raises TypeError here.
        super(VGG_1, self).__init__()

        self.time_step = time_step
        self.s_regs_inp = None
        self.s_regs_conv = None
        self.pool_indices = None
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.lif_conv1 = LIF(time_step, leak)
        self.pool1 = nn.MaxPool2d(kernel_size=2,return_indices=True)

        # conv1 emits 64 channels, so the flattened pool output is 64*14*14.
        self.fc1 = nn.Linear(64 * 14 * 14, 1024, bias=False)
        self.lif_fc1 = LIF(time_step, leak)
        self.fc_out = nn.Linear(1024, 10, bias=False)

    def forward(self, inp):
        """Run ``time_step`` simulation steps and return the summed read-out."""
        size = inp.shape
        self.s_regs_inp = torch.zeros(self.time_step,*size, device=device)
        self.pool_indices = []

        u_out = 0
        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp
            x = self.conv1(spike_inp)
            x = self.lif_conv1(x,t)
            # return_indices=True makes the pool return (values, indices).
            x, pool_idx = self.pool1(x)
            self.pool_indices.append(pool_idx)
            x = x.view(x.shape[0],-1)
            if t == 0:
                # Allocated lazily because the flattened size is only known here.
                self.s_regs_conv = torch.zeros(self.time_step,*x.shape, device=device)
            self.s_regs_conv[t] += x
            x = self.fc1(x)
            x = self.lif_fc1(x,t)
            x = self.fc_out(x)
            u_out = u_out + x
        return u_out

In [7]:
def bp_VGG1(vgg,leak,time_step,du_out,s_regs_conv,l_r,th):
    """Manual weight update for the two FC layers of ``vgg`` at the last time step.

    Only the final time step (index -1) is processed and only ``fc1`` and
    ``fc_out`` are updated; the conv layer and earlier time steps are not
    handled. ``leak`` and ``time_step`` are accepted but currently unused.

    Args:
        vgg: network exposing ``fc1``, ``fc_out`` and ``lif_fc1`` (with
            ``u_regs``, ``du_regs`` and ``s_regs`` filled by a forward pass).
        leak: unused.
        time_step: unused.
        du_out: gradient of the loss w.r.t. the network output
            (presumably (batch, 10) — TODO confirm against the caller).
        s_regs_conv: per-time-step inputs of ``fc1``; indexed with [-1].
        l_r: learning rate.
        th: threshold passed to ``de_func``.

    Returns:
        None. Weights and ``vgg.lif_fc1.du_regs`` are modified in place.
    """
    ## First fc
    du_fc1 = torch.matmul(du_out,vgg.fc_out.weight)*de_func(vgg.lif_fc1.u_regs[-1],th)
    vgg.lif_fc1.du_regs[-1] += du_fc1

    ## Update weight
    w_conv_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_conv[-1])
    # The layer attribute is `fc1`; `fc_1` does not exist on the network.
    vgg.fc1.weight.data -= l_r*w_conv_1

    # Read the spikes from the network passed in, not from the global `toy`.
    w_1_out = torch.matmul(torch.transpose(du_out,0,1),vgg.lif_fc1.s_regs[-1])
    vgg.fc_out.weight.data -= l_r*w_1_out
    
    
    
    
    
    

In [58]:
# Round-trip a 4x4 ramp (values 1..16) through 2x2 max-pooling and unpooling:
# the unpooled tensor keeps each window's maximum at its original position
# and is zero everywhere else.
pool = nn.MaxPool2d(2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(2, stride=2)
input = torch.arange(1., 17.).reshape(1, 1, 4, 4)
output, indices = pool(input)

unpool(output, indices)

tensor([[[[ 0.,  0.,  0.,  0.],
          [ 0.,  6.,  0.,  8.],
          [ 0.,  0.,  0.,  0.],
          [ 0., 14.,  0., 16.]]]])

In [8]:
class LIF(nn.Module):
    """Leaky integrate-and-fire layer with quantised membrane potential.

    The layer is called once per time step. Step 0 allocates the per-step
    buffers; later steps read the previous step's state from them.

    Buffers (first dimension is the time step, the rest matches the input):
        u_regs:  quantised membrane potentials.
        s_regs:  emitted spikes (0 / 1).
        du_regs: zero-initialised scratch space for a manual backward pass.

    Args:
        time_step: number of slots to allocate in each buffer.
        leak: weight of the previous potential in the update; the new input is
            weighted by ``1 - leak``.
    """

    def __init__(self, time_step,leak):
        super(LIF, self).__init__()

        self.u_regs = None
        self.du_regs = None
        self.s_regs = None
        self.leak = leak
        self.time_step = time_step
        self.thresh = 0.5

    def forward(self,inp,t):
        """Advance the layer by one step and return the spikes for step ``t``.

        Args:
            inp: input current for this step.
            t: time-step index; ``t == 0`` (re)allocates all buffers.

        Returns:
            Tensor of 0/1 spikes, also stored in ``s_regs[t]``.
        """
        if t == 0:
            # Allocate on the input's own device so the layer works on CPU and
            # GPU alike instead of depending on a global or on CUDA existing.
            size = inp.shape
            self.u_regs = torch.zeros(self.time_step,*size, device=inp.device)
            self.du_regs = torch.zeros(self.time_step,*size, device=inp.device)
            self.s_regs = torch.zeros(self.time_step,*size, device=inp.device)

        # One scalar N(0, 0.1) sample per call, broadcast over every entry.
        # It is drawn with the default (CPU) generator and then moved.
        err = torch.normal(0, 0.1,(1,1)).to(inp.device)
        inp = inp + err

        if t == 0:
            self.u_regs[0] = quant(inp,2**4)
            # NOTE(review): the first step thresholds the un-quantised input,
            # while later steps threshold the stored quantised potential.
            vol = inp - self.thresh
        else:
            # Leak the previous potential, zeroing it where a spike was emitted.
            carried = self.leak * self.u_regs[t-1] * (1 - self.s_regs[t-1])
            self.u_regs[t] = quant(carried + (1-self.leak)*inp, 2**4)
            vol = self.u_regs[t] - self.thresh

        spike = spike_function(vol, k=1)
        self.s_regs[t] = spike
        return spike


In [9]:
### Back propagation for MLP
def bp_MLP(toy,leak,time_step,du_out,s_regs_inp,l_r,th):
    """Manual backward pass through time for the single-hidden-layer ``MLP``.

    Walks the time steps from last to first. At every step it computes the
    gradient w.r.t. the hidden membrane potential, accumulates it into
    ``toy.lif1.du_regs`` and immediately applies a quantised
    (``quant(..., 2**4)``) update to ``fc_1`` and ``fc_out`` via ``.weight.data``.
    Because the updates are applied in place, each step reads the weights as
    already modified by the steps processed before it.

    Args:
        toy: network exposing ``fc_1``, ``fc_out`` and ``lif1`` whose
            ``u_regs`` / ``s_regs`` / ``du_regs`` were filled by a forward pass.
        leak: leak factor; should match the one used by ``toy.lif1``.
        time_step: number of simulated steps in the forward pass.
        du_out: gradient of the loss w.r.t. the network output; the same
            tensor is reused at every time step. The training cell passes
            softmax(out) - one_hot(target) (presumably (batch, 10) — TODO confirm).
        s_regs_inp: recorded input spikes, indexed by time step.
        l_r: learning rate.
        th: threshold passed to ``de_func``.

    Returns:
        None. ``toy``'s weights and ``toy.lif1.du_regs`` are modified in place.
    """

    ## Last time step: there is no later step, so only the read-out path
    ## contributes, gated element-wise by de_func.
    du_fc1 = torch.matmul(du_out,toy.fc_out.weight)*de_func(toy.lif1.u_regs[-1],th)
    toy.lif1.du_regs[-1] += du_fc1

    ## Update weight: (hidden gradient)^T @ (layer input) for each layer.
    w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[-1])
    toy.fc_1.weight.data -= l_r*quant(w_inp_1,2**4)
    # Un-quantised alternative:
    #     toy.fc_1.weight.data -= l_r*w_inp_1

    w_1_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif1.s_regs[-1])
    toy.fc_out.weight.data -= l_r*quant(w_1_out,2**4)
    # Un-quantised alternative:
    #     toy.fc_out.weight.data -= l_r*w_1_out

    # Remaining steps, newest to oldest (time_step-2 down to 0).
    for t in range(time_step-2,-1,-1):

        ## Gradient w.r.t. the spike at step t: read-out path plus the effect
        ## of the spike on the next potential (the -leak*u[t] term mirrors the
        ## leak*u[t]*(1 - s[t]) factor in LIF.forward).
        ds_fc1 = torch.matmul(du_out,toy.fc_out.weight)+toy.lif1.du_regs[t+1]*(-leak*toy.lif1.u_regs[t])
        ## Gradient w.r.t. the potential at step t: spike path gated by
        ## de_func, plus the direct carry-over leak*(1 - s[t]) into step t+1.
        du_fc1 = (ds_fc1)*de_func(toy.lif1.u_regs[t],th) + toy.lif1.du_regs[t+1]*leak*(1-toy.lif1.s_regs[t])
        toy.lif1.du_regs[t] += du_fc1

        ## Update weight (same rule as for the last step).
        w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[t])
        toy.fc_1.weight.data -= l_r*quant(w_inp_1,2**4)
        # Un-quantised alternative:
        #     toy.fc_1.weight.data -= l_r*w_inp_1

        w_1_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif1.s_regs[t])
        toy.fc_out.weight.data -= l_r*quant(w_1_out,2**4)
        # Un-quantised alternative:
        #     toy.fc_out.weight.data -= l_r*w_1_out


In [10]:
### Back propagation
def bp(toy,leak,time_step,du_out,s_regs_inp,l_r,th):
    """Manual backward pass through time for the two-hidden-layer ``model``.

    Walks the time steps from last to first. At every step it computes the
    gradients w.r.t. the membrane potentials of ``lif2`` and ``lif1``,
    accumulates them into the layers' ``du_regs`` and immediately applies a
    quantised (``quant(..., 2**4)``) update to ``fc_1``, ``fc_2`` and
    ``fc_out`` via ``.weight.data``. Because the updates are applied in place,
    each step reads the weights as already modified by earlier-processed steps.

    Args:
        toy: network exposing ``fc_1``, ``fc_2``, ``fc_out``, ``lif1`` and
            ``lif2`` whose ``u_regs`` / ``s_regs`` / ``du_regs`` were filled by
            a forward pass.
        leak: leak factor; should match the one used by the LIF layers.
        time_step: number of simulated steps in the forward pass.
        du_out: gradient of the loss w.r.t. the network output; the same
            tensor is reused at every time step
            (presumably (batch, 10) — TODO confirm against the caller).
        s_regs_inp: recorded input spikes, indexed by time step.
        l_r: learning rate.
        th: threshold passed to ``de_func``.

    Returns:
        None. ``toy``'s weights and both ``du_regs`` buffers are modified in place.
    """

    ## Second fc, last time step: only the read-out path contributes.
    du_fc2 = torch.matmul(du_out,toy.fc_out.weight)*de_func(toy.lif2.u_regs[-1],th)
    toy.lif2.du_regs[-1] = toy.lif2.du_regs[-1] + du_fc2

    ## First fc, last time step: gradient arrives through fc_2 from du_fc2.
    du_fc1 = torch.matmul(du_fc2,toy.fc_2.weight)*de_func(toy.lif1.u_regs[-1],th)
    toy.lif1.du_regs[-1] += du_fc1


    ## Update weight: (layer gradient)^T @ (layer input) for each layer.
    w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[-1])
    toy.fc_1.weight.data -= l_r*quant(w_inp_1,2**4)
    # Un-quantised alternative: toy.fc_1.weight.data -= l_r*w_inp_1

    w_1_2 = torch.matmul(torch.transpose(du_fc2,0,1),toy.lif1.s_regs[-1])
    toy.fc_2.weight.data -= l_r*quant(w_1_2,2**4)
    # Un-quantised alternative: toy.fc_2.weight.data -= l_r*w_1_2

    w_2_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif2.s_regs[-1])
    toy.fc_out.weight.data -= l_r*quant(w_2_out,2**4)
    # Un-quantised alternative: toy.fc_out.weight.data -= l_r*w_2_out

    # Remaining steps, newest to oldest (time_step-2 down to 0).
    for t in range(time_step-2,-1,-1):

        ## Second fc: spike gradient = read-out path + effect of the spike on
        ## the next potential (the -leak*u[t] term mirrors the
        ## leak*u[t]*(1 - s[t]) factor in LIF.forward); the potential gradient
        ## adds the direct carry-over leak*(1 - s[t]) into step t+1.
        ds_fc2 = torch.matmul(du_out,toy.fc_out.weight)+toy.lif2.du_regs[t+1]*(-leak*toy.lif2.u_regs[t])
        du_fc2 = (ds_fc2)*de_func(toy.lif2.u_regs[t],th) + toy.lif2.du_regs[t+1]*leak*(1-toy.lif2.s_regs[t])
        toy.lif2.du_regs[t] += du_fc2

        ## First fc: same structure, driven by this step's du_fc2 through fc_2.
        ds_fc1 = torch.matmul(du_fc2,toy.fc_2.weight)+toy.lif1.du_regs[t+1]*(-leak*toy.lif1.u_regs[t])
        du_fc1 = (ds_fc1)*de_func(toy.lif1.u_regs[t],th) + toy.lif1.du_regs[t+1]*leak*(1-toy.lif1.s_regs[t])
        toy.lif1.du_regs[t] += du_fc1

        ## Update weight (same rule as for the last step).
        w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[t])
        toy.fc_1.weight.data -= l_r*quant(w_inp_1,2**4)
        # Un-quantised alternative: toy.fc_1.weight.data -= l_r*w_inp_1

        w_1_2 = torch.matmul(torch.transpose(du_fc2,0,1),toy.lif1.s_regs[t])
        toy.fc_2.weight.data -= l_r*quant(w_1_2,2**4)
        # Un-quantised alternative: toy.fc_2.weight.data -= l_r*w_1_2

        w_2_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif2.s_regs[t])
        toy.fc_out.weight.data -= l_r*quant(w_2_out,2**4)
        # Un-quantised alternative: toy.fc_out.weight.data -= l_r*w_2_out


In [11]:
import torch
import torchvision

batch_size_train = 216
batch_size_test = 1000

# Both splits use the same preprocessing: convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# NOTE(review): the dataset root is the filesystem root '/', which needs write
# access there for the download; consider a relative data directory.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_train,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_test,
    shuffle=True,
)

In [12]:
# Bookkeeping for the loss curves.
# n_epochs must equal the epoch count of the training loop: test() runs once
# before training and once after every epoch, so test_losses ends up with
# n_epochs + 1 entries and test_counter needs the same length to be plotted
# against it.
n_epochs = 15
test_losses = []
train_losses = []
train_counter = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]
# A training-loss sample is logged every `log_interval` batches.
log_interval = 10

In [13]:
# Train the single-hidden-layer spiking MLP with the hand-written backward pass.
time_step = 16
leak = 0.98
# Follow the notebook-wide device choice instead of assuming CUDA exists.
toy = MLP(time_step,leak).to(device)

# Draw both weight matrices from N(0, 0.1), then round them with quant.
for layer in (toy.fc_1, toy.fc_out):
    torch.nn.init.normal_(layer.weight, mean=0.0, std=0.1)
    layer.weight.data = quant(layer.weight.data, 2**4)

lr = 0.001
loss = nn.CrossEntropyLoss()

test(toy)
for epoch in range(15):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        # bp_MLP writes the weights through .data, so autograd is never used;
        # disabling it avoids building and holding a graph for every batch.
        with torch.no_grad():
            out = toy(data)
            # Cross-entropy is computed for logging only.
            err = loss(out,target)

            # d(cross-entropy)/d(out) = softmax(out) - one_hot(target).
            # F.softmax is the overflow-safe form of exp(out) / sum(exp(out)),
            # which matters because `out` is a sum over all time steps.
            du_out = F.softmax(out, dim=1) - F.one_hot(target, num_classes=10)

            bp_MLP(toy,leak,time_step,du_out,toy.s_regs_inp,lr,toy.lif1.thresh)

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), err.item()))
            train_losses.append(err.item())
            # Samples seen so far: full batches this epoch plus all earlier
            # epochs (epochs count from 0; the batch size is batch_size_train).
            train_counter.append((batch_idx*batch_size_train) + (epoch*len(train_loader.dataset)))

    test(toy)





Test set: Avg. loss: 3.1644, Accuracy: 1041/10000 (10%)


Test set: Avg. loss: 0.1602, Accuracy: 9507/10000 (95%)


Test set: Avg. loss: 0.1189, Accuracy: 9609/10000 (96%)


Test set: Avg. loss: 0.1010, Accuracy: 9681/10000 (97%)


Test set: Avg. loss: 0.0960, Accuracy: 9683/10000 (97%)


Test set: Avg. loss: 0.0983, Accuracy: 9688/10000 (97%)




Test set: Avg. loss: 0.0903, Accuracy: 9705/10000 (97%)


Test set: Avg. loss: 0.0794, Accuracy: 9746/10000 (97%)


Test set: Avg. loss: 0.0858, Accuracy: 9717/10000 (97%)


Test set: Avg. loss: 0.0768, Accuracy: 9770/10000 (98%)


Test set: Avg. loss: 0.0817, Accuracy: 9744/10000 (97%)


Test set: Avg. loss: 0.0796, Accuracy: 9765/10000 (98%)




Test set: Avg. loss: 0.0806, Accuracy: 9763/10000 (98%)


Test set: Avg. loss: 0.0784, Accuracy: 9753/10000 (98%)


Test set: Avg. loss: 0.0734, Accuracy: 9781/10000 (98%)


Test set: Avg. loss: 0.0753, Accuracy: 9780/10000 (98%)



In [14]:
# import matplotlib.pyplot as plt
# fig = plt.figure()
# plt.plot(train_counter, train_losses, color='blue')
# plt.scatter(test_counter, test_losses, color='red')
# plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
# plt.xlabel('number of training examples seen')
# plt.ylabel('negative log likelihood loss')

In [41]:
# Toy setup for a hand-rolled 2-D convolution backward pass:
# a 2x2 integer kernel W, a 2x2 integer upstream gradient dH, a 3x3 input of
# ones X, and zero-initialised float accumulators dX / dW.
W = torch.tensor([[1, 2], [3, 4]])
dH = torch.tensor([[1, 2], [3, 4]])
X = torch.ones(3, 3)

f = W.shape[-1]                          # kernel width
n_H, n_W = dH.shape[-2], dH.shape[-1]    # rows / columns of dH

dX = torch.zeros(X.shape)
dW = torch.zeros(W.shape)

In [42]:
# Slide the f x f window over every position of dH: scatter the kernel,
# scaled by the upstream gradient, into dX and accumulate the matching input
# patch, scaled the same way, into dW.
for h in range(n_H):
    for w in range(n_W):
        upstream = dH[h, w]
        window = (slice(h, h + f), slice(w, w + f))
        dX[window] += W * upstream
        dW += X[window] * upstream

In [43]:
# Display the kernel gradient accumulated by the loop in the previous cell.
dW

tensor([[10., 10.],
        [10., 10.]])

In [45]:
# Display the top-left 2x2 window of the accumulated input gradient.
dX[0:0+2, 0:0+2]

tensor([[ 1.,  4.],
        [ 6., 20.]])

In [46]:
# Display the kernel scaled by the first upstream-gradient entry dH[0][0].
W*dH[0][0]

tensor([[1, 2],
        [3, 4]])