In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
import gc

In [3]:
#device = torch.device("cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
def PoissonGen(inp, rescale_fac=2.0):
    rand_inp = torch.rand_like(inp)
    return torch.mul(torch.le(rand_inp * rescale_fac, torch.abs(inp)).float(), torch.sign(inp))

# def spike_function(x):
#     x[x>0] = 1
#     x[x<=0] = 0
#     return x

def de_func(U,th):
    alpha = 0.3
    U = alpha*(1.0 - abs((U-th)/th))
    U[U<0]=0
    return U

def test(toy):
    test_loss = 0
    correct = 0
    toy = toy.cuda()
    for data, target in test_loader:
        data = data.cuda()
        target = target.cuda()
        output = toy(data)
        test_loss +=F.cross_entropy(output, target, size_average=False).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

def quant(input, k):
    size = input.size()
    #mean = torch.mean(input.abs(), 1, keepdim=True)
    x = input
    #print(x)
    xmax = x.abs().max()
    num_bits=k
    v0 = 1
    v1 = 2
    v2 = -0.5
    y = k #2.**num_bits - 1.
    #print(y)
    x = x.add(v0).div(v1)
    #print(x)
    x = x.mul(y).round_()
    #print(x)
    x = x.div(y)
    #print(x)
    x = x.add(v2)
    #print(x)
    x = x.mul(v1)
    #print(x)
    input = x
    return input

In [5]:
class model(nn.Module):
    def __init__(self, time_step,leak):
        super(model, self).__init__()
        
       
        self.fc_1 = nn.Linear(28*28,256,bias=False)
        self.fc_2 = nn.Linear(256,256,bias=False)
        self.fc_out = nn.Linear(256,10,bias=False)
        
        self.lif1 = LIF(time_step,leak)
        self.lif2 = LIF(time_step,leak)
        self.time_step = time_step
        self.s_regs_inp = None
        
    def forward(self, inp):
        inp = inp.view(inp.shape[0],-1)
        size = inp.shape
        self.s_regs_inp = torch.zeros(self.time_step,*size, device=device)
        u_out = 0
        
        for t in range(self.time_step):
            
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp 
            
            x = self.fc_1(spike_inp)
            #x = quant(x,2**4)
            x = self.lif1(x, t)
            x = self.fc_2(x)
            #x = quant(x,2**4)
            x = self.lif2(x, t)
            x = self.fc_out(x)
            u_out = u_out + x
        return u_out/self.time_step

In [6]:
class MLP(nn.Module):
    def __init__(self,time_step,leak):
        super(MLP, self).__init__()
        
        self.fc_1 = nn.Linear(28*28,512,bias=False)
        self.fc_out = nn.Linear(512,10,bias=False)
        self.lif1 = LIF(time_step,leak)
        self.time_step = time_step
        self.s_regs_inp = None
        
    def forward(self, inp):
#         print("size is:", (inp.view(inp.shape[0],1,28,28)).shape)
        inp = inp.view(inp.shape[0],-1)
        size = inp.shape
        
        self.s_regs_inp = torch.zeros(self.time_step,*size, device=device)
        u_out = 0
        
        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp 
            x = self.fc_1(spike_inp)
            #x = quant(x,2**4)
            x = self.lif1(x, t)
            x = self.fc_out(x)
            u_out = u_out + x
        return u_out
        

In [7]:
class VGG_5(nn.Module):
    def __init__(self,time_step, leak):
        super(VGG_5, self).__init__()
        
        self.time_step = time_step
        self.s_regs_inp = None
        self.s_regs_conv = None
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.conv_lif1 = LIF(time_step, leak)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.pool1_ind = []
        self.unpool1 = nn.MaxUnpool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
        self.conv_lif2 = LIF(time_step, leak)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)
        self.conv_lif3 = LIF(time_step, leak)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.pool2_ind = []
        self.unpool2 = nn.MaxUnpool2d(kernel_size=2)

        self.fc1 = nn.Linear(128 * 7 * 7, 1024, bias=False)
        self.fc_lif1 = LIF(time_step)
        self.fc_out = nn.Linear(1024, 10, bias=False)
        
    def forward(self, inp):

        size = inp.shape
        self.s_regs_inp = torch.zeros(self.time_step,*size, device=device)
        self.pool1_ind = []
        self.pool2_ind = []
        u_out = 0
        
        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp 
            x = self.conv1(spike_inp)
            x = self.conv_lif1(x,t)
            x, indices = self.pool1(x)
            self.pool1_ind.append(indices)
            x = self.conv2(x)
            x = self.conv_lif2(x,t)
            x = self.conv3(x)
            x = self.conv_lif3(x,t)
            x, indices = self.pool2(x)
            x = x.view(x.shape[0],-1)
            
            if t == 0:
                self.s_regs_conv = torch.zeros(self.time_step,*x.shape, device=device)
            self.pool2_ind.append(indices)
            self.s_regs_conv[t] += x
            
            x = self.fc1(x)
            x = self.fc_lif1(x,t)
            x = self.fc_out(x)
            u_out = u_out + x
        return u_out

In [8]:
def bp_VGG5(vgg,leak,time_step,du_out,l_r,th):
   
    ## Update weight in FCs, time Ts
    du_fc1 = torch.matmul(du_out,vgg.fc_out.weight)*de_func(vgg.fc_lif1.u_regs[-1],th)
    vgg.fc_lif1.du_regs[-1] += du_fc1
    w_conv_1 = torch.matmul(torch.transpose(du_fc1,0,1),vgg.s_regs_conv[-1])
    vgg.fc1.weight.data -= l_r*w_conv_1   
    w_1_out = torch.matmul(torch.transpose(du_out,0,1),vgg.fc_lif1.s_regs[-1])
    vgg.fc_out.weight.data -= l_r*w_1_out
    
    ## Update du in pool2, time T
    dx_pool2 = torch.matmul(du_fc1,vgg.fc1.weight)
    dx_pool2 = dx_pool2.view(dx_pool2.shape[0],128,7,7)
    du_pool2 = vgg.unpool2(dx_pool2,vgg.pool2_ind[-1])
    du_pool2 = torch.sum(du_pool2,0)
    du_pool2 = torch.unsqueeze(du_pool2,1)
    
    ## Update du in conv3, time T
    d_conv3 = nn.Conv2d(128, 128, stride=1, kernel_size=f, padding=f-1, bias=False)
    
    ## Update weight in Conv3, time T
    f = du_pool2.shape[-1]
    d_conv3 = nn.Conv2d(128, 128, stride=1, padding=1, kernel_size=f, bias=False)
    d_conv3.weight.data = du_pool2.type(torch.float)
    dW_conv3 = d_conv3(vgg.conv_lif2.s_regs[-1].type(torch.float))
    dW_conv3 = torch.sum(dW_conv3,0)
    dW_conv3 = torch.unsqueeze(dW_conv3,1)
    vgg.conv3.weight.data -= l_r*dW_conv3
    
    
    
    for t in range(time_step-2,-1,-1):
        
        ds_fc1 = torch.matmul(du_out,vgg.fc_out.weight)+vgg.lif_fc1.du_regs[t+1]*(-leak*vgg.lif_fc1.du_regs[t])
        du_fc1 = (ds_fc1)*de_func(vgg.lif_fc1.du_regs[t],th) + vgg.lif_fc1.du_regs[t+1]*leak*(1-vgg.lif_fc1.s_regs[t])
        vgg.lif_fc1.du_regs[t] += du_fc1
        
        w_conv_1 = torch.matmul(torch.transpose(du_fc1,0,1),vgg.s_regs_conv[t])
        vgg.fc1.weight.data -= l_r*w_conv_1
        
        
        dx_pool1 = torch.matmul(du_fc1,vgg.fc1.weight)
        dx_pool1 = dx_pool1.view(dx_pool1.shape[0],16,14,14)
        du_pool1 = vgg.unpool1(dx_pool1,vgg.pool1_ind[t])
        du_pool1 = torch.sum(du_pool1,0)
        du_pool1 = torch.unsqueeze(du_pool1,1)
        f = du_pool1.shape[-1]
        d_conv1_w = nn.Conv2d(1, 16, stride=1, padding=1,kernel_size=f, bias=False)
        d_conv1_w.weight.data = du_pool1.type(torch.float)
        dW = d_conv1_w(vgg.s_regs_inp[t].type(torch.float))
        dW = torch.sum(dW,0)
        dW = torch.unsqueeze(dW,1)

        vgg.conv1.weight.data -= l_r*dW
    
    
    return 0

In [9]:
class VGG_1(nn.Module):
    def __init__(self,time_step,leak):
        super(VGG_1, self).__init__()
        
        self.time_step = time_step
        self.s_regs_inp = None
        self.s_regs_conv = None
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1, bias=False)
        
#         self.deconv1 = nn.Conv2d()
        self.lif_conv1 = LIF(time_step,leak)
        self.pool1 = nn.MaxPool2d(kernel_size=2,return_indices=True)
        self.pool1_ind = []
        self.unpool1 = nn.MaxUnpool2d(kernel_size=2)

        self.fc1 = nn.Linear(16 * 14 * 14, 512, bias=False)
        self.lif_fc1 = LIF(time_step,leak)
        self.fc_out = nn.Linear(512, 10, bias=False)
        
    def forward(self, inp):

        size = inp.shape
        self.s_regs_inp = torch.zeros(self.time_step,*size, device=device)
        self.pool1_ind = []

        u_out = 0
        for t in range(self.time_step):
            spike_inp = PoissonGen(inp)
            self.s_regs_inp[t] += spike_inp
            x = self.conv1(spike_inp)
            x = self.lif_conv1(x,t)
            x, indices = self.pool1(x)
            x= x.view(x.shape[0],-1)
            
            if t == 0:
                self.s_regs_conv = torch.zeros(self.time_step,*x.shape, device=device)
            self.pool1_ind.append(indices)
            self.s_regs_conv[t] += x
            
            x = self.fc1(x)
            x = self.lif_fc1(x,t)
            
            x = self.fc_out(x)
            u_out = u_out + x

        return u_out

In [10]:
def bp_VGG1(vgg,leak,time_step,du_out,l_r,th):
   
    ## First fc
    du_fc1 = torch.matmul(du_out,vgg.fc_out.weight)*de_func(vgg.lif_fc1.u_regs[-1],th)
    vgg.lif_fc1.du_regs[-1] += du_fc1
       
    ## Update weight
    w_conv_1 = torch.matmul(torch.transpose(du_fc1,0,1),vgg.s_regs_conv[-1])
    vgg.fc1.weight.data -= l_r*w_conv_1
     
    w_1_out = torch.matmul(torch.transpose(du_out,0,1),vgg.lif_fc1.s_regs[-1])
    vgg.fc_out.weight.data -= l_r*w_1_out
    
    dx_pool1 = torch.matmul(du_fc1,vgg.fc1.weight)
    dx_pool1 = dx_pool1.view(dx_pool1.shape[0],16,14,14)
    du_pool1 = vgg.unpool1(dx_pool1,vgg.pool1_ind[-1])
    du_pool1 = torch.sum(du_pool1,0)
    du_pool1 = torch.unsqueeze(du_pool1,1)
    f = du_pool1.shape[-1]
    d_conv1_w = nn.Conv2d(1, 16, stride=1, padding=1,kernel_size=f, bias=False)
    d_conv1_w.weight.data = du_pool1.type(torch.float)
    dW = d_conv1_w(vgg.s_regs_inp[-1].type(torch.float))
    dW = torch.sum(dW,0)
    dW = torch.unsqueeze(dW,1)

    vgg.conv1.weight.data -= l_r*dW
    
    for t in range(time_step-2,-1,-1):
        
        ds_fc1 = torch.matmul(du_out,vgg.fc_out.weight)+vgg.lif_fc1.du_regs[t+1]*(-leak*vgg.lif_fc1.du_regs[t])
        du_fc1 = (ds_fc1)*de_func(vgg.lif_fc1.du_regs[t],th) + vgg.lif_fc1.du_regs[t+1]*leak*(1-vgg.lif_fc1.s_regs[t])
        vgg.lif_fc1.du_regs[t] += du_fc1
        
        w_conv_1 = torch.matmul(torch.transpose(du_fc1,0,1),vgg.s_regs_conv[t])
        vgg.fc1.weight.data -= l_r*w_conv_1
        
        
        dx_pool1 = torch.matmul(du_fc1,vgg.fc1.weight)
        dx_pool1 = dx_pool1.view(dx_pool1.shape[0],16,14,14)
        du_pool1 = vgg.unpool1(dx_pool1,vgg.pool1_ind[t])
        du_pool1 = torch.sum(du_pool1,0)
        du_pool1 = torch.unsqueeze(du_pool1,1)
        f = du_pool1.shape[-1]
        d_conv1_w = nn.Conv2d(1, 16, stride=1, padding=1,kernel_size=f, bias=False)
        d_conv1_w.weight.data = du_pool1.type(torch.float)
        dW = d_conv1_w(vgg.s_regs_inp[t].type(torch.float))
        dW = torch.sum(dW,0)
        dW = torch.unsqueeze(dW,1)

        vgg.conv1.weight.data -= l_r*dW
    
    
    return 0
    
    

In [11]:
class LIF(nn.Module):
    def __init__(self, time_step,leak):
        super(LIF, self).__init__()
        
        self.u_regs = None
        self.du_regs = None
        self.s_regs = None
        self.leak = leak
        self.time_step = time_step
        self.thresh = 0.5
        
    def forward(self,inp,t):
        
#         print("memory before clear",torch.cuda.memory_allocated())
        if t == 0:
            size = inp.shape
            self.u_regs = torch.zeros(self.time_step,*size, device=device)
            self.du_regs = torch.zeros(self.time_step,*size, device=device)
#             err = torch.normal(0, 0.1,(1,1)).cuda()
#             inp = inp + err
#             self.u_regs[0] = quant(inp,2**4)
            self.u_regs[0] = inp
            self.s_regs = torch.zeros(self.time_step,*size, device=device)

            spike = inp.gt(self.thresh).float()

            self.s_regs[0] = spike
            
        else:
#             err = torch.normal(0, 0.1,(1,1))
#             inp = inp + err
#             self.u_regs[t] = quant(self.leak * self.u_regs[t-1] * (1 - self.s_regs[t-1]) + (1-self.leak)*inp, 2**4)
            self.u_regs[t] = self.leak*self.u_regs[t-1]*(1-self.s_regs[t-1]) + inp

            spike = self.u_regs[t].gt(self.thresh).float()

            self.s_regs[t] = spike
            
#         print("memory after clear",torch.cuda.memory_allocated())
#         torch.cuda.empty_cache()
#         gc.collect()
        return spike


In [12]:
### Back propagation for MLP
def bp_MLP(toy,leak,time_step,du_out,s_regs_inp,l_r,th):
    
    ## First fc
    du_fc1 = torch.matmul(du_out,toy.fc_out.weight)*de_func(toy.lif1.u_regs[-1],th)
    toy.lif1.du_regs[-1] += du_fc1

    ## Update weight
    w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[-1])
#     toy.fc_1.weight.data -= l_r*quant(w_inp_1,2**4)
    toy.fc_1.weight.data -= l_r*w_inp_1
    
    w_1_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif1.s_regs[-1])
#     toy.fc_out.weight.data -= l_r*quant(w_1_out,2**4)
    toy.fc_out.weight.data -= l_r*w_1_out

    for t in range(time_step-2,-1,-1):
        
        ## First fc
        ds_fc1 = torch.matmul(du_out,toy.fc_out.weight)+toy.lif1.du_regs[t+1]*(-leak*toy.lif1.u_regs[t])
        du_fc1 = (ds_fc1)*de_func(toy.lif1.u_regs[t],th) + toy.lif1.du_regs[t+1]*leak*(1-toy.lif1.s_regs[t])
        toy.lif1.du_regs[t] += du_fc1

        ## Update weight
        w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[t])
#         toy.fc_1.weight.data -= l_r*quant(w_inp_1,2**4)
#         print("du_size",du_fc1.shape)
#         print("s_size",s_regs_inp[t].shape)
#         print("dweight shape",w_inp_1.shape)
        toy.fc_1.weight.data -= l_r*w_inp_1

        w_1_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif1.s_regs[t])
#         toy.fc_out.weight.data -= l_r*quant(w_1_out,2**4)
        toy.fc_out.weight.data -= l_r*w_1_out


In [13]:
### Back propagation
def bp(toy,leak,time_step,du_out,s_regs_inp,l_r,th):
    
    ## Second fc    
    du_fc2 = torch.matmul(du_out,toy.fc_out.weight)*de_func(toy.lif2.u_regs[-1],th)    
    toy.lif2.du_regs[-1] = toy.lif2.du_regs[-1] + du_fc2
    
    ## First fc
    du_fc1 = torch.matmul(du_fc2,toy.fc_2.weight)*de_func(toy.lif1.u_regs[-1],th)
    toy.lif1.du_regs[-1] += du_fc1

    
    ## Update weight
    w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[-1])
    toy.fc_1.weight.data -= l_r*quant(w_inp_1,2**4)
    #toy.fc_1.weight.data -= l_r*w_inp_1

    w_1_2 = torch.matmul(torch.transpose(du_fc2,0,1),toy.lif1.s_regs[-1])
    toy.fc_2.weight.data -= l_r*quant(w_1_2,2**4)
    #toy.fc_2.weight.data -= l_r*w_1_2

    w_2_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif2.s_regs[-1])
    toy.fc_out.weight.data -= l_r*quant(w_2_out,2**4)
    #toy.fc_out.weight.data -= l_r*w_2_out

    for t in range(time_step-2,-1,-1):

        ds_fc2 = torch.matmul(du_out,toy.fc_out.weight)+toy.lif2.du_regs[t+1]*(-leak*toy.lif2.u_regs[t])
        du_fc2 = (ds_fc2)*de_func(toy.lif2.u_regs[t],th) + toy.lif2.du_regs[t+1]*leak*(1-toy.lif2.s_regs[t])
        toy.lif2.du_regs[t] += du_fc2
        
        ## First fc
        ds_fc1 = torch.matmul(du_fc2,toy.fc_2.weight)+toy.lif1.du_regs[t+1]*(-leak*toy.lif1.u_regs[t])
        du_fc1 = (ds_fc1)*de_func(toy.lif1.u_regs[t],th) + toy.lif1.du_regs[t+1]*leak*(1-toy.lif1.s_regs[t])
        toy.lif1.du_regs[t] += du_fc1

        ## Update weight
        w_inp_1 = torch.matmul(torch.transpose(du_fc1,0,1),s_regs_inp[t])
        toy.fc_1.weight.data -= l_r*quant(w_inp_1,2**4)
        
        #toy.fc_1.weight.data -= l_r*w_inp_1

        w_1_2 = torch.matmul(torch.transpose(du_fc2,0,1),toy.lif1.s_regs[t])
        toy.fc_2.weight.data -= l_r*quant(w_1_2,2**4)
        #toy.fc_2.weight.data -= l_r*w_1_2

        w_2_out = torch.matmul(torch.transpose(du_out,0,1),toy.lif2.s_regs[t])
        toy.fc_out.weight.data -= l_r*quant(w_2_out,2**4)
        #toy.fc_out.weight.data -= l_r*w_2_out


In [14]:
import torch
import torchvision

batch_size_train = 64
batch_size_test = 100

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [15]:
test_losses = []
train_losses = []
train_counter = []
test_counter = [i*len(train_loader.dataset) for i in range(5 + 1)]
log_interval = 10

In [15]:
time_step = 20
leak = 0.99
toy = VGG_1(time_step,leak).cuda()
# vgg = VGG_5(time_step)
# vgg =vgg.cuda()
# print("weight",toy.fc_1.weight)
# torch.nn.init.normal_(toy.fc_1.weight, mean=0.0, std=0.1)
# toy.fc_1.weight.data = quant(toy.fc_1.weight,2**4)
# torch.nn.init.normal_(toy.fc_2.weight, mean=0.0, std=0.1)
# toy.fc_2.weight.data = quant(toy.fc_2.weight,2**4)
# torch.nn.init.normal_(toy.fc_out.weight, mean=0.0, std=0.1)
# toy.fc_out.weight.data = quant(toy.fc_out.weight,2**4)
# print("quantized weight",toy.fc_1.weight)
lr = 0.02
loss = nn.CrossEntropyLoss()

test(toy)
with torch.no_grad():
    for epoch in range(12):
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.cuda()
            target = target.cuda()
            out = toy(data)
#             print("memory after fwd",torch.cuda.memory_allocated()/10000000)

            err = loss(out,target)

            exp = torch.exp(out)
            exp_sum = torch.sum(torch.exp(out),1, keepdim=True)   
            target = F.one_hot(target, num_classes=10)
            #L = -1*torch.sum((target*torch.log((exp/exp_sum))),1, keepdim=True)
            du_out = exp/exp_sum
            du_out = (du_out - target)/batch_size_train



    #         vgg_out = vgg(data)
    #         exp_vgg = torch.exp(vgg_out)
    #         exp_sum_vgg = torch.sum(torch.exp(vgg_out),1, keepdim=True)
    #         du_out_vgg = exp_vgg/exp_sum_vgg
    #         du_out_vgg = du_out_vgg - target
    #         print(du_out_vgg)



            bp_VGG1(toy,leak,time_step,du_out,lr,toy.lif_conv1.thresh)
#             print("memory after bp",torch.cuda.memory_allocated()/10000000)

    #         bp_MLP(toy,leak,time_step,du_out,toy.s_regs_inp,lr,toy.lif1.thresh)



            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), err.item()))

#             del toy.lif_conv1.s_regs
#             del toy.lif_conv1.u_regs
#             del toy.lif_conv1.du_regs
#             del toy.lif_fc1.du_regs
#             del toy.lif_fc1.u_regs
#             del toy.lif_fc1.s_regs
#             del toy.s_regs_conv
#             del toy.s_regs_inp
#             del data
#             del target
#             torch.cuda.empty_cache()
            
#             gc.collect()
#             print("memory after clear",torch.cuda.memory_allocated()/10000000)

        test(toy)





Test set: Avg. loss: 3.6960, Accuracy: 888/10000 (9%)


Test set: Avg. loss: 0.2411, Accuracy: 9258/10000 (93%)



KeyboardInterrupt: 

In [None]:
# a = torch.rand(16,1,3,3)
# b = torch.rand(10,16,28,28)
# a[:,:][:,:].shape

In [None]:
# import matplotlib.pyplot as plt
# fig = plt.figure()
# plt.plot(train_counter, train_losses, color='blue')
# plt.scatter(test_counter, test_losses, color='red')
# plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
# plt.xlabel('number of training examples seen')
# plt.ylabel('negative log likelihood loss')

In [51]:
# W = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype = torch.float)
# f = W.shape[-1]


# n_W = dH.shape[-1]
# n_H = dH.shape[-2]
# X = torch.ones(5,5)
# dX = torch.zeros(X.shape)
# dW = torch.zeros(W.shape)

torch.Size([10, 16, 28, 28])

In [52]:
# for h in range(n_H):
#     for w in range(n_W):
#         dX[h:h+f, w:w+f] += W * dH[h][w]
#         dW += X[h:h+f, w:w+f] * dH[h][w]

In [53]:
# dW

In [54]:
# dX

In [55]:
# W

In [17]:

# d_conv1 = nn.Conv2d(16, 16, stride=1, kernel_size=3, padding=1, bias=False)
d_conv2 = nn.Conv2d(16, 16, stride=1, kernel_size=3, padding =1, bias=False)

In [98]:
# print(d_conv1.weight)

In [99]:
W = torch.ones(16,16,3,3,dtype =torch.float)
dH = torch.ones(10,16,7,7,dtype =torch.float)
X = torch.ones(10,16,7,7,dtype =torch.float)

# dH = torch.sum(dH,0)
# dH = torch.unsqueeze(dH,1)
# dH.shape

# dH = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
# W = torch.unsqueeze(W,0)
# W = torch.unsqueeze(W,1)
# dH = torch.unsqueeze(dH,0)
# dH = torch.unsqueeze(dH,1)
# print(W.shape)
# print(dH.shape)

In [100]:
# d_conv1.weight.data = torch.flip(W,[-1,-2])
# ddx = d_conv1(dH)
# # ddx
# ddx.shape

In [101]:
d_conv2.weight.data = dH.type(torch.float)
# print(d_conv2.weight.shape)

# # X = torch.ones(1,1,5,5)

ddx1 = d_conv2(X.type(torch.float))
ddx1.shape

torch.Size([10, 10, 3, 3])

In [23]:
daconv2 = nn.Conv2d(1, 2, stride=1, kernel_size=3, padding ='same', bias=False)
w = daconv2.weight

In [24]:
a = torch.ones(2,1,3,3)
daconv2(a).shape

torch.Size([2, 2, 3, 3])

In [1]:
torch

1.9.0


In [30]:
a = torch.ones(2,2,3,3)

In [31]:
c = nn.Conv2d(2,3,stride=1,kernel_size=2, padding = 'same', bias = False)

In [32]:
# c.weight.data = torch.ones(3,2,2,2)
c.weight.shape

torch.Size([3, 2, 2, 2])

In [29]:
c(a)

tensor([[[[8., 8., 4.],
          [8., 8., 4.],
          [4., 4., 2.]],

         [[8., 8., 4.],
          [8., 8., 4.],
          [4., 4., 2.]],

         [[8., 8., 4.],
          [8., 8., 4.],
          [4., 4., 2.]]],


        [[[8., 8., 4.],
          [8., 8., 4.],
          [4., 4., 2.]],

         [[8., 8., 4.],
          [8., 8., 4.],
          [4., 4., 2.]],

         [[8., 8., 4.],
          [8., 8., 4.],
          [4., 4., 2.]]]], grad_fn=<MkldnnConvolutionBackward>)

In [33]:
c = nn.Conv1d(2,2,kernel_size=2)

In [34]:
c.weight.shape

torch.Size([2, 2, 2])

In [9]:
a = torch.ones(2,2,3,3)
w = torch.ones(2,2,3,3)
c = torch.matmul(a,w)
c.shape

torch.Size([2, 2, 3, 3])