In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
import gc
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
# Compute device shared by the notebook: the first CUDA GPU when one is
# visible, otherwise the CPU. Replace with torch.device("cpu") to force CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

In [331]:
def PoissonGen(inp, rescale_fac=2.0):
    """Draw one step of signed rate-coded spikes from ``inp``.

    Each element fires when a fresh uniform sample in [0, 1), scaled by
    ``rescale_fac``, is less than or equal to ``|inp|``. A firing element
    carries the sign of the input, so the result holds values in {-1, 0, +1}
    as floats with the same shape as ``inp``.
    """
    noise = torch.rand_like(inp)
    fires = (noise * rescale_fac <= inp.abs()).float()
    return fires * inp.sign()

# def spike_function(x):
#     x[x>0] = 1
#     x[x<=0] = 0
#     return x

def de_func(U, th, alpha=0.3):
    """Triangular surrogate derivative of the spike function.

    Returns ``alpha * (1 - |(U - th) / th|)`` with negative values set to
    zero, so the result peaks at ``alpha`` when ``U == th`` and is zero once
    ``U`` is at least ``th`` away from the threshold.

    Args:
        U: membrane potentials (tensor); not modified.
        th: firing threshold; must be non-zero because it is used as divisor.
        alpha: peak height of the surrogate (default 0.3, the value that was
            previously hard-coded).

    Returns:
        A new tensor with the same shape as ``U``.
    """
    grad = alpha * (1.0 - abs((U - th) / th))
    # The expression above allocates a new tensor, so masking in place here
    # does not touch the caller's U.
    grad[grad < 0] = 0
    return grad

def test(toy,data,test_loader):
    """Evaluate ``toy`` on ``test_loader`` and print loss and accuracy.

    Args:
        toy: the network to evaluate; it is moved to the notebook-level
            ``device``.
        data: unused; kept so existing calls keep working.
        test_loader: iterable of ``(inputs, targets)`` batches exposing a
            ``dataset`` attribute with a length.

    Side effects:
        Appends the per-sample average cross-entropy to the notebook-level
        ``test_losses`` list and prints a summary line.
    """
    toy = toy.to(device)
    total_loss = 0.0
    correct = 0
    # Evaluation needs no autograd graph.
    with torch.no_grad():
        for images, target in test_loader:
            images = images.to(device)
            target = target.to(device)
            output = toy(images)
            # Sum per-sample losses so that dividing by the dataset size below
            # yields a true average regardless of batch size.
            total_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    n_samples = len(test_loader.dataset)
    test_loss = total_loss / n_samples
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_samples,
        100. * correct / n_samples))
    
    
#     correct = 0
#     total = 0
#     # since we're not training, we don't need to calculate the gradients for our outputs
#     with torch.no_grad():
#         for data in test_loader:
#             toy = toy.cuda()
            
#             images, labels = data
#             images = images.cuda()
#             labels = labels.cuda()
#             # calculate outputs by running images through the network
#             outputs = toy(images)
#             # the class with the highest energy is what we choose as prediction
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()

#     print('Accuracy of the network on the 10000 test images: %d %%' % (
#         100 * correct / total))

def quant(input, k):
    """Snap ``input`` to a uniform grid with spacing ``2 / k``.

    The value is mapped with ``(x + 1) / 2``, multiplied by ``k``, rounded to
    the nearest integer, divided by ``k`` and mapped back with
    ``(x - 0.5) * 2``. No clamping is applied, so values outside [-1, 1] are
    rounded onto the same grid rather than saturated.

    Args:
        input: tensor to quantize; not modified.
        k: number of grid steps across a unit interval (callers in this
            notebook pass ``2**4``).

    Returns:
        A new tensor of the same shape as ``input``.
    """
    unit = input.add(1).div(2)
    unit = unit.mul(k).round().div(k)
    return unit.add(-0.5).mul(2)

def conv_weight_update(dH,X,pad):
    """Weight gradient of a stride-1 2-D convolution, averaged over the batch.

    For a forward pass ``H = conv2d(X, W, padding=pad)`` the gradient with
    respect to ``W`` is the cross-correlation of the padded input with the
    output gradient, summed over the batch:

        dW[o, i, u, v] = sum_{b, h, w} dH[b, o, h, w] * Xpad[b, i, h + u, w + v]

    Swapping the batch and channel axes of both tensors lets a single
    ``F.conv2d`` call perform that reduction: the batch axis becomes the
    "input channel" axis that conv2d sums over.

    Args:
        dH: gradient w.r.t. the convolution output, shape (B, C_out, H', W').
        X: input that was fed to the convolution, shape (B, C_in, H, W).
        pad: padding used by the forward convolution.

    Returns:
        Tensor of shape (C_out, C_in, kH, kW): the summed gradient divided by
        the batch size B.
    """
    batch = dH.shape[0]
    # (C_in, B, H, W) correlated with (C_out, B, H', W') -> (C_in, C_out, kH, kW)
    grad_w = F.conv2d(X.transpose(0, 1), dH.transpose(0, 1), padding=pad)
    return grad_w.transpose(0, 1) / batch

def conv_dx_update(dH,W,pad):
    """Input gradient of a stride-1 2-D convolution.

    For a forward pass ``H = conv2d(X, W, padding=pad)`` the gradient with
    respect to ``X`` is obtained by convolving ``dH`` with the kernel whose
    in/out channel axes are swapped and whose spatial axes are flipped, using
    padding ``k - 1 - pad`` along each spatial axis.

    Args:
        dH: gradient w.r.t. the convolution output, shape (B, C_out, H', W').
        W: forward kernel, shape (C_out, C_in, kH, kW).
        pad: padding of the *forward* convolution: an int, a 2-sequence, or
            the strings ``'same'`` (treated as ``(k - 1) // 2``, exact for odd
            kernel sizes) and ``'valid'`` (zero padding).

    Returns:
        Gradient w.r.t. the convolution input, shape (B, C_in, H, W).

    Raises:
        ValueError: if ``pad`` is a string other than 'same' or 'valid'.
    """
    k_h, k_w = W.shape[-2], W.shape[-1]
    if isinstance(pad, str):
        if pad == 'same':
            fwd_pad = ((k_h - 1) // 2, (k_w - 1) // 2)
        elif pad == 'valid':
            fwd_pad = (0, 0)
        else:
            raise ValueError("unsupported padding mode: {!r}".format(pad))
    elif isinstance(pad, (tuple, list)):
        fwd_pad = (int(pad[0]), int(pad[1]))
    else:
        fwd_pad = (int(pad), int(pad))

    # "Full" correlation shrunk by the forward padding restores the input size.
    back_pad = (k_h - 1 - fwd_pad[0], k_w - 1 - fwd_pad[1])
    W_back = torch.flip(torch.transpose(W, 0, 1), [-1, -2])
    return F.conv2d(dH, W_back, padding=back_pad)

In [4]:
class model(nn.Module):
    """Spiking MLP: 784 -> 256 -> 256 -> 10 with an LIF layer after each hidden layer.

    All linear layers are bias-free. During ``forward`` the Poisson-encoded
    input spikes of every time step are stored in ``s_regs_inp`` so that the
    hand-written backward pass can read them afterwards.

    Args:
        time_step: number of simulation steps per forward pass.
        leak: leak factor handed to both LIF layers.
    """

    def __init__(self, time_step,leak):
        super(model, self).__init__()

        self.fc_1 = nn.Linear(28*28,256,bias=False)
        self.fc_2 = nn.Linear(256,256,bias=False)
        self.fc_out = nn.Linear(256,10,bias=False)

        self.lif1 = LIF(time_step,leak)
        self.lif2 = LIF(time_step,leak)
        self.time_step = time_step
        # Filled by forward(): input spikes, shape (time_step, batch, features).
        self.s_regs_inp = None

    def forward(self, inp):
        """Run ``time_step`` steps and return the time-averaged output layer sum."""
        inp = inp.view(inp.shape[0],-1)
        # Allocate on the input's own device so the module does not depend on
        # a notebook-level global matching where the data lives.
        self.s_regs_inp = torch.zeros(self.time_step, *inp.shape, device=inp.device)
        u_out = 0

        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp

            x = self.fc_1(spike_inp)
            x = self.lif1(x, t)
            x = self.fc_2(x)
            x = self.lif2(x, t)
            x = self.fc_out(x)
            u_out = u_out + x
        return u_out/self.time_step

In [5]:
class MLP(nn.Module):
    """Spiking MLP: 784 -> 512 -> 10 with one LIF hidden layer.

    Both linear layers are bias-free. ``forward`` stores the Poisson-encoded
    input spikes of every time step in ``s_regs_inp`` for the hand-written
    backward pass.

    Args:
        time_step: number of simulation steps per forward pass.
        leak: leak factor handed to the LIF layer.
    """

    def __init__(self,time_step,leak):
        super(MLP, self).__init__()

        self.fc_1 = nn.Linear(28*28,512,bias=False)
        self.fc_out = nn.Linear(512,10,bias=False)
        self.lif1 = LIF(time_step,leak)
        self.time_step = time_step
        # Filled by forward(): input spikes, shape (time_step, batch, features).
        self.s_regs_inp = None

    def forward(self, inp):
        """Run ``time_step`` steps and return the output layer summed over time."""
        inp = inp.view(inp.shape[0],-1)
        # Allocate on the input's own device rather than a notebook global.
        self.s_regs_inp = torch.zeros(self.time_step, *inp.shape, device=inp.device)
        u_out = 0

        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp
            x = self.fc_1(spike_inp)
            x = self.lif1(x, t)
            x = self.fc_out(x)
            u_out = u_out + x
        return u_out
        

In [6]:
class VGG_5(nn.Module):
    """Spiking conv net: conv-LIF-pool, conv-LIF, conv-LIF-pool, fc-LIF, fc.

    All layers are bias-free. ``forward`` records what the hand-written
    backward pass needs: input spikes (``s_regs_inp``), flattened spikes fed
    to ``fc1`` (``s_regs_conv``) and the max-pool indices of every time step
    (``pool1_ind`` / ``pool2_ind``).

    Args:
        time_step: number of simulation steps per forward pass.
        leak: leak factor handed to every LIF layer.
        data: ``"cifar10"`` (3 input channels, 8x8 map before ``fc1``) or
            ``"mnist"`` (1 input channel, 7x7 map before ``fc1``).

    Raises:
        ValueError: if ``data`` is neither of the two supported names.
    """

    def __init__(self,time_step, leak,data):
        super(VGG_5, self).__init__()

        if data == "cifar10":
            input_dim = 3
            pre_linear_dim = 8
        elif data == "mnist":
            input_dim = 1
            pre_linear_dim = 7
        else:
            raise ValueError(
                "data must be 'cifar10' or 'mnist', got {!r}".format(data))

        self.time_step = time_step
        self.s_regs_inp = None
        self.s_regs_conv = None
        self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False)
        self.conv_lif1 = LIF(time_step, leak)
        # Pools return indices so the backward pass can route gradients with
        # the matching MaxUnpool2d.
        self.pool1 = nn.MaxPool2d(kernel_size=2,return_indices=True)
        self.pool1_ind = []
        self.unpool1 = nn.MaxUnpool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
        self.conv_lif2 = LIF(time_step, leak)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
        self.conv_lif3 = LIF(time_step, leak)
        self.pool2 = nn.MaxPool2d(kernel_size=2,return_indices=True)
        self.pool2_ind = []
        self.unpool2 = nn.MaxUnpool2d(kernel_size=2)

        self.fc1 = nn.Linear(128 * pre_linear_dim * pre_linear_dim, 1024, bias=False)
        self.fc_lif1 = LIF(time_step,leak)
        self.fc_out = nn.Linear(1024, 10, bias=False)

    def forward(self, inp):
        """Run ``time_step`` steps and return the output layer summed over time."""
        # Allocate on the input's own device rather than a notebook global.
        self.s_regs_inp = torch.zeros(self.time_step, *inp.shape, device=inp.device)
        self.pool1_ind = []
        self.pool2_ind = []
        u_out = 0

        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp
            x = self.conv1(spike_inp)
            x = self.conv_lif1(x,t)
            x,indices = self.pool1(x)
            self.pool1_ind.append(indices)
            x = self.conv2(x)
            x = self.conv_lif2(x,t)
            x = self.conv3(x)
            x = self.conv_lif3(x,t)
            x,indices = self.pool2(x)
            x = x.view(x.shape[0],-1)

            if t == 0:
                # Shape is only known after the first pass through the convs.
                self.s_regs_conv = torch.zeros(self.time_step, *x.shape, device=x.device)
            self.pool2_ind.append(indices)
            self.s_regs_conv[t] += x

            x = self.fc1(x)
            x = self.fc_lif1(x,t)
            x = self.fc_out(x)
            u_out = u_out + x
        return u_out

In [7]:
def bp_VGG5(vgg,leak,time_step,du_out,l_r,mom,lamb,th,pre_linear_dim,input_dim):
    """Hand-written backward pass through time for ``VGG_5`` with in-place updates.

    Walks the time steps from last to first. At each step it computes the
    membrane-potential gradient of every LIF layer, accumulates it into that
    layer's ``du_regs``, and immediately applies a weight update

        step = grad + mom * previous_step        (no momentum term at the last time step)
        W   -= l_r * (step + lamb * |W|)

    to every layer. Only the ``vgg`` argument is read; no notebook globals.

    Args:
        vgg: a ``VGG_5`` instance whose ``forward`` has just been run, so its
            spike/potential registers and pool indices are populated.
        leak: LIF leak factor used in the forward pass.
        time_step: number of time steps; assumed equal to ``vgg.time_step``.
        du_out: gradient of the loss w.r.t. the summed output, shape (B, 10).
        l_r: learning rate.
        mom: momentum factor applied across time steps.
        lamb: coefficient of the ``|W|`` regularisation term.
        th: firing threshold passed to ``de_func``.
        pre_linear_dim: spatial size of the feature map before ``fc1``.
        input_dim: unused; kept for call compatibility.

    Returns:
        0 (kept for compatibility with existing callers).
    """

    def lif_delta(lif, grad_spike_in, t, is_last):
        # Gradient w.r.t. the membrane potential at step t. For all but the
        # last step the next step's gradient flows back through the leak
        # (u[t+1] depends on u[t]) and through the reset (u[t+1] depends on
        # s[t] with derivative -leak * u[t]).
        surrogate = de_func(lif.u_regs[t], th)
        if is_last:
            du = grad_spike_in * surrogate
        else:
            du_next = lif.du_regs[t + 1]
            ds = grad_spike_in + du_next * (-leak * lif.u_regs[t])
            du = ds * surrogate + du_next * leak * (1 - lif.s_regs[t])
        lif.du_regs[t] += du
        return du

    def apply_step(layer, grad, prev_step):
        # Momentum across time steps plus the |W| regularisation term.
        step = grad if prev_step is None else grad + mom * prev_step
        layer.weight.data -= l_r * (step + lamb * torch.abs(layer.weight))
        return step

    prev_fc1 = prev_fc_out = prev_conv3 = prev_conv2 = prev_conv1 = None
    last = time_step - 1

    # Gradients are computed by hand, so no autograd graph is needed.
    with torch.no_grad():
        for t in range(last, -1, -1):
            is_last = (t == last)

            ## Fully connected layers
            du_fc1 = lif_delta(vgg.fc_lif1, torch.matmul(du_out, vgg.fc_out.weight), t, is_last)
            prev_fc1 = apply_step(
                vgg.fc1, torch.matmul(du_fc1.t(), vgg.s_regs_conv[t]), prev_fc1)
            prev_fc_out = apply_step(
                vgg.fc_out, torch.matmul(du_out.t(), vgg.fc_lif1.s_regs[t]), prev_fc_out)

            ## Back through fc1 and pool2
            # NOTE(review): as before, this uses fc1's weights *after* the
            # update just applied above — confirm that is intended.
            dx_pool2 = torch.matmul(du_fc1, vgg.fc1.weight)
            dx_pool2 = dx_pool2.view(
                dx_pool2.shape[0], vgg.conv3.out_channels, pre_linear_dim, pre_linear_dim)
            du_pool2 = vgg.unpool2(dx_pool2, vgg.pool2_ind[t])

            ## conv3: the weight gradient uses the membrane gradient, matching
            ## how conv2 and conv1 are handled below.
            du_conv3 = lif_delta(vgg.conv_lif3, du_pool2, t, is_last)
            prev_conv3 = apply_step(
                vgg.conv3,
                conv_weight_update(du_conv3.float(), vgg.conv_lif2.s_regs[t].float(), 1),
                prev_conv3)

            ## conv2: its forward input was the max-pooled conv_lif1 spikes
            du_conv2 = lif_delta(
                vgg.conv_lif2, conv_dx_update(du_conv3, vgg.conv3.weight, 'same'), t, is_last)
            conv2_input = F.max_pool2d(vgg.conv_lif1.s_regs[t].float(), kernel_size=2)
            prev_conv2 = apply_step(
                vgg.conv2, conv_weight_update(du_conv2.float(), conv2_input, 1), prev_conv2)

            ## Back through conv2 and pool1
            du_pool1 = vgg.unpool1(
                conv_dx_update(du_conv2, vgg.conv2.weight, 'same'), vgg.pool1_ind[t])

            ## conv1
            du_conv1 = lif_delta(vgg.conv_lif1, du_pool1, t, is_last)
            prev_conv1 = apply_step(
                vgg.conv1,
                conv_weight_update(du_conv1.float(), vgg.s_regs_inp[t].float(), 1),
                prev_conv1)

    return 0

In [8]:
#     du_pool2 = torch.sum(du_pool2,0)
#     du_pool2 = torch.unsqueeze(du_pool2,1)
    
    ## Update du in conv3, time T
#     d_conv3 = nn.Conv2d(128, 128, stride=1, kernel_size=f, padding=f-1, bias=False)
    
    ## Update weight in Conv3, time T
#     f = du_pool2.shape[-1]
#     d_conv3 = nn.Conv2d(128, 128, stride=1, padding=1, kernel_size=f, bias=False)
#     d_conv3.weight.data = du_pool2.type(torch.float)
#     dW_conv3 = d_conv3(vgg.conv_lif2.s_regs[-1].type(torch.float))
#     dW_conv3 = torch.sum(dW_conv3,0)
#     dW_conv3 = torch.unsqueeze(dW_conv3,1)
#     vgg.conv3.weight.data -= l_r*dW_conv3

In [9]:
class VGG_1(nn.Module):
    """Small spiking conv net for 1x28x28 inputs: conv-LIF-pool, fc-LIF, fc.

    All layers are bias-free. ``forward`` records what the hand-written
    backward pass needs: input spikes (``s_regs_inp``), flattened pooled
    spikes fed to ``fc1`` (``s_regs_conv``) and the max-pool indices of every
    time step (``pool1_ind``).

    Args:
        time_step: number of simulation steps per forward pass.
        leak: leak factor handed to both LIF layers.
    """

    def __init__(self,time_step,leak):
        super(VGG_1, self).__init__()

        self.time_step = time_step
        self.s_regs_inp = None
        self.s_regs_conv = None
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1, bias=False)

        self.lif_conv1 = LIF(time_step,leak)
        # Indices are kept so the backward pass can undo the pooling.
        self.pool1 = nn.MaxPool2d(kernel_size=2,return_indices=True)
        self.pool1_ind = []
        self.unpool1 = nn.MaxUnpool2d(kernel_size=2)

        self.fc1 = nn.Linear(16 * 14 * 14, 512, bias=False)
        self.lif_fc1 = LIF(time_step,leak)
        self.fc_out = nn.Linear(512, 10, bias=False)

    def forward(self, inp):
        """Run ``time_step`` steps and return the output layer summed over time."""
        # Allocate on the input's own device rather than a notebook global.
        self.s_regs_inp = torch.zeros(self.time_step, *inp.shape, device=inp.device)
        self.pool1_ind = []

        u_out = 0
        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp
            x = self.conv1(spike_inp)
            x = self.lif_conv1(x,t)
            x, indices = self.pool1(x)
            x = x.view(x.shape[0],-1)

            if t == 0:
                # Shape is only known after the first pass through the conv.
                self.s_regs_conv = torch.zeros(self.time_step, *x.shape, device=x.device)
            self.pool1_ind.append(indices)
            self.s_regs_conv[t] += x

            x = self.fc1(x)
            x = self.lif_fc1(x,t)

            x = self.fc_out(x)
            u_out = u_out + x

        return u_out

In [10]:
def bp_VGG1(vgg,leak,time_step,du_out,l_r,th):
    """Hand-written backward pass through time for ``VGG_1`` with in-place updates.

    Walks the time steps from last to first, accumulating the membrane
    gradient of ``lif_fc1`` into its ``du_regs`` and applying plain SGD steps
    to ``fc1`` and ``conv1`` at every step. ``fc_out`` is updated at the last
    time step only.

    NOTE(review): no surrogate derivative is applied for ``lif_conv1`` — the
    unpooled gradient is used directly as the conv1 output gradient — and
    ``fc_out`` receives no update for earlier time steps. Both are carried
    over unchanged; confirm they are intended.

    Args:
        vgg: a ``VGG_1`` instance whose ``forward`` has just been run.
        leak: LIF leak factor used in the forward pass.
        time_step: number of time steps.
        du_out: gradient of the loss w.r.t. the summed output, shape (B, 10).
        l_r: learning rate.
        th: firing threshold passed to ``de_func``.

    Returns:
        0 (kept for compatibility with existing callers).
    """

    def conv1_weight_grad(du_fc1, t):
        # Route the fc1 gradient back through the flatten and the max-pool.
        dx_pool1 = torch.matmul(du_fc1, vgg.fc1.weight)
        dx_pool1 = dx_pool1.view(dx_pool1.shape[0],16,14,14)
        du_pool1 = vgg.unpool1(dx_pool1, vgg.pool1_ind[t]).float()
        inputs = vgg.s_regs_inp[t].float()
        # Correlate each sample's input with its own output gradient and sum
        # over the batch: swapping batch and channel axes makes conv2d reduce
        # over the batch. Result shape: (16, 1, 3, 3).
        dW = F.conv2d(inputs.transpose(0, 1), du_pool1.transpose(0, 1), padding=1)
        return dW.transpose(0, 1)

    # Gradients are computed by hand, so no autograd graph is needed.
    with torch.no_grad():
        ## Last time step
        du_fc1 = torch.matmul(du_out,vgg.fc_out.weight)*de_func(vgg.lif_fc1.u_regs[-1],th)
        vgg.lif_fc1.du_regs[-1] += du_fc1

        vgg.fc1.weight.data -= l_r*torch.matmul(du_fc1.t(), vgg.s_regs_conv[-1])
        vgg.fc_out.weight.data -= l_r*torch.matmul(du_out.t(), vgg.lif_fc1.s_regs[-1])
        vgg.conv1.weight.data -= l_r*conv1_weight_grad(du_fc1, -1)

        ## Earlier time steps
        for t in range(time_step-2,-1,-1):
            du_next = vgg.lif_fc1.du_regs[t+1]
            # The reset term and the surrogate both depend on the stored
            # membrane potential u_regs[t], as in bp_MLP / bp.
            ds_fc1 = torch.matmul(du_out,vgg.fc_out.weight) + du_next*(-leak*vgg.lif_fc1.u_regs[t])
            du_fc1 = ds_fc1*de_func(vgg.lif_fc1.u_regs[t],th) + du_next*leak*(1-vgg.lif_fc1.s_regs[t])
            vgg.lif_fc1.du_regs[t] += du_fc1

            vgg.fc1.weight.data -= l_r*torch.matmul(du_fc1.t(), vgg.s_regs_conv[t])
            vgg.conv1.weight.data -= l_r*conv1_weight_grad(du_fc1, t)

    return 0
    
    

In [11]:
class LIF(nn.Module):
    """Leaky integrate-and-fire layer that records its state for every time step.

    The membrane potential follows

        u[0] = inp
        u[t] = leak * u[t-1] * (1 - s[t-1]) + inp      (t > 0)

    and a spike ``s[t] = 1`` is emitted wherever ``u[t] > thresh`` (0.5).
    Potentials and spikes of every step are kept in ``u_regs`` / ``s_regs``;
    ``du_regs`` is allocated as zeros at ``t == 0`` and is left for the
    hand-written backward functions to fill.

    Args:
        time_step: number of time steps the registers are sized for.
        leak: multiplicative leak applied to the previous potential.
    """

    def __init__(self, time_step,leak):
        super(LIF, self).__init__()

        self.u_regs = None
        self.du_regs = None
        self.s_regs = None
        self.leak = leak
        self.time_step = time_step
        self.thresh = 0.5

    def forward(self,inp,t):
        """Advance one time step and return the spike tensor (floats in {0, 1}).

        Calling with ``t == 0`` re-allocates all registers, so a new forward
        pass must start at ``t == 0``.
        """
        if t == 0:
            # Allocate on the input's own device so the layer works wherever
            # its input lives, without relying on a notebook-level global.
            shape = (self.time_step, *inp.shape)
            self.u_regs = torch.zeros(shape, device=inp.device)
            self.du_regs = torch.zeros(shape, device=inp.device)
            self.s_regs = torch.zeros(shape, device=inp.device)
            membrane = inp
        else:
            # Leak the previous potential, zeroing it where the neuron fired.
            membrane = self.leak*self.u_regs[t-1]*(1-self.s_regs[t-1]) + inp

        self.u_regs[t] = membrane
        spike = self.u_regs[t].gt(self.thresh).float()
        self.s_regs[t] = spike
        return spike


In [12]:
### Back propagation for MLP
def bp_MLP(toy,leak,time_step,du_out,s_regs_inp,l_r,th):
    """Hand-written backward pass through time for the one-hidden-layer ``MLP``.

    Starting at the last time step and moving backwards, the membrane
    gradient of ``toy.lif1`` is computed, accumulated into its ``du_regs``,
    and plain SGD steps are applied in place to ``fc_1`` and then ``fc_out``.

    Args:
        toy: network exposing ``fc_1``, ``fc_out`` and ``lif1`` with populated
            ``u_regs`` / ``s_regs`` / ``du_regs``.
        leak: LIF leak factor.
        time_step: number of time steps to walk back through.
        du_out: gradient of the loss w.r.t. the network output.
        s_regs_inp: per-time-step input spikes recorded by the forward pass.
        l_r: learning rate.
        th: firing threshold passed to ``de_func``.
    """
    hidden = toy.lif1

    def sgd_updates(delta_hidden, t):
        # fc_1 is updated before fc_out, one after the other.
        toy.fc_1.weight.data -= l_r * torch.matmul(delta_hidden.t(), s_regs_inp[t])
        toy.fc_out.weight.data -= l_r * torch.matmul(du_out.t(), hidden.s_regs[t])

    # Last time step: there is no later step to propagate from.
    delta = torch.matmul(du_out, toy.fc_out.weight) * de_func(hidden.u_regs[-1], th)
    hidden.du_regs[-1] += delta
    sgd_updates(delta, -1)

    # Remaining steps, newest to oldest.
    for t in reversed(range(time_step - 1)):
        delta_next = hidden.du_regs[t + 1]
        # Gradient reaching the spike: from the output plus the reset path.
        grad_spike = torch.matmul(du_out, toy.fc_out.weight) + delta_next * (-leak * hidden.u_regs[t])
        # Through the surrogate, plus the leak path from the next step.
        delta = grad_spike * de_func(hidden.u_regs[t], th) + delta_next * leak * (1 - hidden.s_regs[t])
        hidden.du_regs[t] += delta
        sgd_updates(delta, t)


In [13]:
### Back propagation
def bp(toy,leak,time_step,du_out,s_regs_inp,l_r,th):
    """Hand-written backward pass through time for the two-hidden-layer ``model``.

    Starting at the last time step and moving backwards, the membrane
    gradients of ``lif2`` and then ``lif1`` are computed and accumulated into
    their ``du_regs``. Each weight gradient is passed through
    ``quant(..., 2**4)`` before being applied in place, in the order
    ``fc_1``, ``fc_2``, ``fc_out``.

    Args:
        toy: network exposing ``fc_1``, ``fc_2``, ``fc_out``, ``lif1``, ``lif2``.
        leak: LIF leak factor.
        time_step: number of time steps to walk back through.
        du_out: gradient of the loss w.r.t. the network output.
        s_regs_inp: per-time-step input spikes recorded by the forward pass.
        l_r: learning rate.
        th: firing threshold passed to ``de_func``.
    """
    first_lif, second_lif = toy.lif1, toy.lif2

    def membrane_grad(lif, grad_spike_in, t, final_step):
        if final_step:
            # Nothing flows back from a later step.
            return grad_spike_in * de_func(lif.u_regs[t], th)
        later = lif.du_regs[t + 1]
        grad_spike = grad_spike_in + later * (-leak * lif.u_regs[t])
        return grad_spike * de_func(lif.u_regs[t], th) + later * leak * (1 - lif.s_regs[t])

    def quantized_update(layer, delta, presyn_spikes):
        grad = torch.matmul(delta.t(), presyn_spikes)
        layer.weight.data -= l_r * quant(grad, 2**4)

    # The last step is addressed as index -1, the others by their own index.
    schedule = [(-1, True)] + [(t, False) for t in range(time_step - 2, -1, -1)]
    for t, final_step in schedule:
        ## Second fc
        du_fc2 = membrane_grad(second_lif, torch.matmul(du_out, toy.fc_out.weight), t, final_step)
        second_lif.du_regs[t] += du_fc2

        ## First fc
        du_fc1 = membrane_grad(first_lif, torch.matmul(du_fc2, toy.fc_2.weight), t, final_step)
        first_lif.du_regs[t] += du_fc1

        ## Update weight
        quantized_update(toy.fc_1, du_fc1, s_regs_inp[t])
        quantized_update(toy.fc_2, du_fc2, first_lif.s_regs[t])
        quantized_update(toy.fc_out, du_out, second_lif.s_regs[t])
        #toy.fc_out.weight.data -= l_r*w_2_out


In [14]:
import torch
import torchvision

batch_size_train = 128
batch_size_test = 1000

# Same preprocessing for both splits: tensor conversion followed by
# normalisation with mean 0.1307 and std 0.3081.
_mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split, downloaded to ./mnist if missing; batches are shuffled.
_mnist_train_set = torchvision.datasets.MNIST(
    './mnist', train=True, download=True, transform=_mnist_transform)
train_loader_mnist = torch.utils.data.DataLoader(
    _mnist_train_set, batch_size=batch_size_train, shuffle=True)

# Test split; shuffling is enabled here as well.
_mnist_test_set = torchvision.datasets.MNIST(
    './mnist', train=False, download=True, transform=_mnist_transform)
test_loader_mnist = torch.utils.data.DataLoader(
    _mnist_test_set, batch_size=batch_size_test, shuffle=True)

In [15]:
# import urllib.request
# import ssl

# ssl._create_default_https_context = ssl._create_unverified_context
# response = urllib.request.urlopen('https://www.python.org')
# print(response.read().decode('utf-8'))

In [16]:
# CIFAR-10 preprocessing: tensor conversion, then per-channel normalisation
# with mean 0.5 and std 0.5.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([_to_tensor, _normalize])

batch_size = 32

# Training split (downloaded to ./cifar10 if missing), shuffled each epoch.
trainset = torchvision.datasets.CIFAR10(
    root='./cifar10', train=True, download=True, transform=transform)
train_loader_cifar10 = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# Test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./cifar10', train=False, download=True, transform=transform)
test_loader_cifar10 = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Class names listed in label-index order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [17]:
# test_losses collects one average loss per call to test();
# log_interval is the number of batches between training progress prints.
test_losses, log_interval = [], 10

In [18]:
def check_sparsity(toy):
    """Placeholder: ignores ``toy`` and always returns 0."""
    sparsity = 0
    return sparsity

In [19]:
# --- Training configuration -------------------------------------------------
time_step = 8
leak = 0.99

data = "cifar10"
toy = VGG_5(time_step,leak,data).to(device)

if data == "cifar10":
    train_loader = train_loader_cifar10
    input_dim = 3
    pre_linear_dim = 8
    test_loader = test_loader_cifar10
elif data == "mnist":
    train_loader = train_loader_mnist
    input_dim = 1
    pre_linear_dim = 7
    test_loader = test_loader_mnist
else:
    raise ValueError("data must be 'cifar10' or 'mnist', got {!r}".format(data))

lr = 0.001
mom = 0.9
lamb = 0.0001
loss = nn.CrossEntropyLoss()

running_loss = 0.0

# --- Training loop ------------------------------------------------------------
# The network's forward pass runs without autograd; only the loss is
# differentiated w.r.t. the network output, and that gradient is handed to the
# hand-written backward pass, which updates the weights in place.
for epoch in range(15):
    # The batch tensor is named `images` so that `data` keeps holding the
    # dataset name used above and passed to test() below.
    for batch_idx, (images, target) in enumerate(train_loader):
        images = images.to(device)
        target = target.to(device)
        with torch.no_grad():
            out = toy(images)

        # Turn the output into a leaf that records a gradient.
        out = out.detach().requires_grad_(True)
        err = F.cross_entropy(out, target,reduction='mean')
        err.backward()
        du_out = out.grad

        # 0.5 is the firing threshold handed to the surrogate derivative; it
        # equals the threshold hard-coded in LIF.
        bp_VGG5(toy,leak,time_step,du_out,lr,mom,lamb,0.5,pre_linear_dim,input_dim)
        # For the MLP variant use instead:
        # bp_MLP(toy,leak,time_step,du_out,toy.s_regs_inp,lr,0.5)

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(images), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), err.item()))

    # Evaluate once per epoch.
    test(toy,data,test_loader)




KeyboardInterrupt: 

In [None]:
# Scratch instance of VGG_5 with a single time step and leak 0.5.
# NOTE(review): this rebinds the name `model`, which is also the name of the
# two-hidden-layer MLP class defined earlier in this notebook; after this cell
# runs, `model(...)` no longer constructs that class. Consider another name.
model = VGG_5(1,0.5,"cifar10").cuda()


In [None]:
# class test(nn.Module):
#     def __init__(self,time_step, leak, data):
#         super(test, self).__init__()
        
#         if data == "cifar10":
#             input_dim = 3
#             pre_linear_dim = 8
#         elif data == "mnist":
#             input_dim = 1
#             pre_linear_dim = 7
            
#         self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False)
#         self.conv_lif1 = LIF(time_step, leak)
#         self.pool1 = nn.MaxPool2d(kernel_size=2,return_indices=True)
#         self.pool1_ind = []
#         self.unpool1 = nn.MaxUnpool2d(kernel_size=2)
        
#         self.fc1 = nn.Linear(128 * pre_linear_dim * pre_linear_dim, 1024, bias=False)
#         self.fc_lif1 = LIF(time_step,leak)
#         self.fc_out = nn.Linear(1024, 10, bias=False)
        
        
#         def forward(self, inp):

#         size = inp.shape
#         self.s_regs_inp = torch.zeros(self.time_step,*size, device=device)
#         self.pool1_ind = []
#         u_out = 0
        
#         for t in range(self.time_step):
#             spike_inp = PoissonGen(inp)
#             self.s_regs_inp[t] += spike_inp 
#             x = self.conv1(spike_inp)
#             x = self.conv_lif1(x,t)
#             x,indices = self.pool1(x)
#             self.pool1_ind.append(indices)
#             x = x.view(x.shape[0],-1)
            
#             if t == 0:
#                 self.s_regs_conv = torch.zeros(self.time_step,*x.shape, device=device)
#             self.s_regs_conv[t] += x
            
#             x = self.fc1(x)
#             x = self.fc_lif1(x,t)
#             x = self.fc_out(x)
#             u_out = u_out + x
#         return u_out

In [347]:
# Scratch check: build a tiny conv -> max-pool graph with autograd enabled so
# the hand-written gradient helpers can be compared with PyTorch's gradients.
a = torch.rand(4,1,4,4,device=device,requires_grad=True)
w = torch.randn(2,1,3,3,device=device,requires_grad=True)
conv1 = F.conv2d(a,w,padding=1)
# conv1 = Variable(conv1,requires_grad=True)
pool1 = nn.MaxPool2d(kernel_size=2, return_indices=True)
unpool1 = nn.MaxUnpool2d(kernel_size=2)
x,indices = pool1(conv1)
# x = Variable(x,requires_grad=True)
# Random upstream gradient with the same shape as the pooled output.
err = torch.rand_like(x)

In [348]:
# Back-propagate `err` through the pool and the conv; this populates a.grad
# and w.grad, the autograd reference values used below.
x.backward(err)

In [352]:
# Manual counterpart of the backward pass: place `err` back at the positions
# selected by the max-pool, then push it through the conv input-gradient helper.
du_pool = unpool1(err,indices)
# NOTE(review): the lines that compute `dw` are commented out, yet later cells
# display `dw`; on a fresh kernel those cells raise NameError.
# dw0 = conv_weight_update(torch.unsqueeze(du_pool[0],0),torch.unsqueeze(a[0],0),1)
# dw1 = conv_weight_update(torch.unsqueeze(du_pool[1],0),torch.unsqueeze(a[1],0),1)
# dw2 = conv_weight_update(torch.unsqueeze(du_pool[2],0),torch.unsqueeze(a[2],0),1)
# dw3 = conv_weight_update(torch.unsqueeze(du_pool[3],0),torch.unsqueeze(a[3],0),1)
# dw = (dw0+dw1+dw2+dw3)
# dww = conv_weight_update(du_pool,a,1)
dx = conv_dx_update(du_pool,w,1)

In [346]:
# Shape of the unpooled gradient.
du_pool.shape

torch.Size([4, 2, 4, 4])

In [354]:
# Display the manually computed weight gradient.
# NOTE(review): `dw` is only assigned in commented-out lines of an earlier
# cell, so this relies on leftover kernel state.
dw

tensor([[[[2.6056, 3.2645, 3.6455],
          [2.7099, 1.5400, 3.5414],
          [3.5173, 3.7435, 3.5242]]],


        [[[3.1123, 3.0722, 2.8452],
          [3.2613, 3.9904, 6.9552],
          [2.5672, 3.7917, 4.1968]]]], device='cuda:0', grad_fn=<AddBackward0>)

In [353]:
# Autograd's weight gradient, for comparison with `dw`.
w.grad

tensor([[[[2.6056, 3.2645, 3.6455],
          [2.7099, 1.5400, 3.5414],
          [3.5173, 3.7435, 3.5242]]],


        [[[3.1123, 3.0722, 2.8452],
          [3.2613, 3.9904, 6.9552],
          [2.5672, 3.7917, 4.1968]]]], device='cuda:0')

In [355]:
# Sum of the signed differences between the manual and autograd weight
# gradients.
# NOTE(review): positive and negative errors can cancel in a signed sum; a
# max-abs difference or torch.allclose would be a stricter check.
torch.sum(dw-w.grad)

tensor(2.3842e-07, device='cuda:0', grad_fn=<SumBackward0>)

In [304]:
# Sum of the signed differences between the manual and autograd input
# gradients.
# NOTE(review): positive and negative errors can cancel in a signed sum; a
# max-abs difference or torch.allclose would be a stricter check.
torch.sum(dx-a.grad)

tensor(1.4342e-07, device='cuda:0', grad_fn=<SumBackward0>)

In [327]:
# Scratch 2x2 tensor of ones used by the next cell.
c = torch.ones(2,2)

In [329]:
# Scale the second row in place via indexed assignment.
# NOTE(review): not idempotent — every re-run divides that row by 5 again.
c[1] = c[1]/5

In [357]:
# Rebinds `a` (previously the 4x1x4x4 conv input above) to a new random
# 4x2x4x4 tensor for the view experiments that follow.
a = torch.rand(4,2,4,4,device=device,requires_grad=True)

In [360]:
# Shape of the tensor created above.
a.shape

torch.Size([4, 2, 4, 4])

In [361]:
# Check that the batch and channel axes (4 and 2) can be viewed as a single
# axis of size 8 behind a leading axis of size 1.
a.view(1,8,4,4).shape

torch.Size([1, 8, 4, 4])

In [None]:
a.view