In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import numpy as np

In [2]:
torch.ones(2,3)*1e-5

tensor([[1.0000e-05, 1.0000e-05, 1.0000e-05],
        [1.0000e-05, 1.0000e-05, 1.0000e-05]])

In [3]:
class Conv2dFunctionN(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
        # Save arguments to context to use on backward
        # WARNING : if stride, padding, dilation etc is array, this will not work properly!!!!
#         print('stride', stride)
        if weight.shape[2] == 1 :
            padding = 0
        elif weight.shape[2] == 5 :
            padding = 2
        elif weight.shape[2] == 7 :
            padding = 3
        confs = torch.from_numpy(np.array([stride, padding, dilation, groups]))
        dinput = torch.ones(input.shape) * 1e-5
        dinput = dinput.cuda() 
        input += dinput
        out = F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        dout = F.conv2d(dinput, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        dout = dout.cuda()
        ctx.save_for_backward(input, out, dout, weight, bias, confs)

        # Compute Convolution
        return out
    
    @staticmethod
    def backward(ctx, grad_output):
        # Load saved tensors
        input, out, dout, weight, bias, confs = ctx.saved_variables
        confs = confs.numpy()
        stride, padding, dilation, groups= confs[0], confs[1], confs[2], confs[3]

        # Calculate Gradient
        grad_input = grad_weight = grad_bias = None
        dinput = torch.ones(input.shape) * 1e-5
        dinput = dinput.cuda()
        
        
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride, padding, dilation, groups)
            
        if ctx.needs_input_grad[1]:
            grad_output = grad_output*(dout)
            grad_weight = torch.nn.grad.conv2d_weight(dinput, weight.shape, grad_output, stride, padding, dilation, groups)
                
        # WARNING : Bias maybe buggy, remove if it is buggy
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)


        # WARNING : Bias maybe buggy, remove if it is buggy
        if bias is not None:
            return grad_input, grad_weight, grad_bias, None, None, None, None
        else:
            return grad_input, grad_weight, None, None, None, None, None

In [4]:
# device = torch.device("cpu")
device = torch.device("cuda")
dtype = torch.float
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 64

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [5]:
# N, C_in, C_out, K_size = batch_size, 3, 12, 3
# Create random Tensors for weights.
conw1 = torch.randn(8,3,5,5, device=device, dtype=dtype, requires_grad=True)
conw2 = torch.randn(32,8,3,3, device=device, dtype=dtype, requires_grad=True)
conw3 = torch.randn(128,32,3,3, device=device, dtype=dtype, requires_grad=True)
conw4 = torch.randn(128,128,3,3, device=device, dtype=dtype, requires_grad=True)
conw5 = torch.randn(10,128,1,1, device=device, dtype=dtype, requires_grad=True)

conw1 = torch.nn.init.xavier_uniform_(conw1, gain=1.0)
conw2 = torch.nn.init.xavier_uniform_(conw2, gain=1.0)
conw3 = torch.nn.init.xavier_uniform_(conw3, gain=1.0)
conw4 = torch.nn.init.xavier_uniform_(conw4, gain=1.0)
conw5 = torch.nn.init.xavier_uniform_(conw5, gain=1.0)

# conw1g = conw1.clone().detatch(requires_grad=True)
conw1n = torch.tensor(conw1, device=device, dtype=dtype, requires_grad=True)
conw2n = torch.tensor(conw2, device=device, dtype=dtype, requires_grad=True)
conw3n = torch.tensor(conw3, device=device, dtype=dtype, requires_grad=True)
conw4n = torch.tensor(conw4, device=device, dtype=dtype, requires_grad=True)
conw5n = torch.tensor(conw5, device=device, dtype=dtype, requires_grad=True)

# for adam 
conw1a = torch.tensor(conw1, device=device, dtype=dtype, requires_grad=True)
conw2a = torch.tensor(conw2, device=device, dtype=dtype, requires_grad=True)
conw3a = torch.tensor(conw3, device=device, dtype=dtype, requires_grad=True)
conw4a = torch.tensor(conw4, device=device, dtype=dtype, requires_grad=True)
conw5a = torch.tensor(conw5, device=device, dtype=dtype, requires_grad=True)

conw1a_list = [conw1a, conw2a, conw3a, conw4a, conw5a]
conw1n_list = [conw1n, conw2n, conw3n, conw4n, conw5n]

  app.launch_new_instance()


In [6]:
class Net(nn.Module):
    def __init__(self, conv_list, conv):
        super(Net, self).__init__()
        self.conv1 = conv.apply
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = conv.apply
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = conv.apply
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = conv.apply
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = conv.apply
        self.avgpool = torch.nn.AvgPool2d((2,2) ,stride=(2,2))
        self.maxpool = torch.nn.MaxPool2d((2,2), stride=(2,2))
        self.linear = torch.nn.Linear(128, 10)
        self.act = torch.nn.ReLU()
        self.tanh = torch.nn.Tanh()
        self.device = torch.device("cuda")
        self.dtype = torch.float
        self.c1 = None
        self.c2 = None
        self.c3 = None
        self.c4 = None
        self.c5 = None
        
        # weights of batch norm for custom backward, normal case 
        self.conw1, self.conw2, self.conw3, self.conw4, self.conw5 = conv_list 

    def forward(self, x):
        self.c1 = self.conv1(x, self.conw1)
        x = self.bn1(self.act(self.c1))
        x = self.maxpool(x)
        self.c2 = self.conv2(x, self.conw2)
        x = self.bn2(self.act(self.c2))
        x = self.maxpool(x)
        self.c3 = self.conv3(x, self.conw3)
        x = self.bn3(self.act(self.c3))
        x = self.maxpool(x)
        self.c4 = self.conv4(x, self.conw4)
        x = self.bn4(self.act(self.c4))
        x = self.maxpool(x)
        self.c5 = self.conv5(x, self.conw5)
        x = self.avgpool(self.c5)
        x = torch.squeeze(x)
#         x = self.linear(x)
        x = torch.nn.Softmax(dim=1)(x)
#         x = torch.sigmoid(x)

#         self.c1 = self.act(self.conv1(x, self.conw1))
#         x = self.maxpool(self.c1)
#         self.c2 = self.act(self.conv2(x, self.conw2))
#         x = self.maxpool(self.c2)
#         self.c3 = self.act(self.conv3(x, self.conw3))
#         x = self.maxpool(self.c3)
#         self.c4 = self.act(self.conv4(x, self.conw4))
#         x = self.maxpool(self.c4)
#         self.c5 = self.conv5(x, self.conw5)
#         x = self.avgpool(self.c5)
#         x = torch.squeeze(x)
# #         x = self.linear(x)
#         x = torch.nn.Softmax(dim=1)(x)
# #         x = torch.sigmoid(x)
        
        return x
    
    def conv_return (self) :
        return self.c1, self.c2, self.c3, self.c4, self.c5        
     

In [7]:
net = Net(conw1n_list, Conv2dFunctionN).to(device)
criterion = nn.CrossEntropyLoss().to(device)

In [8]:
image, labels = iter(trainloader).next()
outputs = net(image.to(device)).to(device)
print(outputs.shape)
loss = criterion(outputs, labels.to(device))
loss.backward()
print(torch.nn.Softmax(dim=1)(outputs).sum(dim=1))
# outputs

torch.Size([64, 10])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000], device='cuda:0', grad_fn=<SumBackward1>)




In [9]:
def test (model) :
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            outputs = model(inputs.to(device))
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))
    return (100 * correct / total)    

In [10]:
NUM_EPOCH = 10
lr_list = [ 0.01, 0.01, 0.01, 0.01, 0.005,
           0.005, 0.005, 0.005, 0.001, 0.001, 0.001, 0.0005, 0.0005]

In [11]:

print('****************** Numerical Gardient update ****************')
# lr_list = [0.05] * NUM_EPOCH
criterion = nn.CrossEntropyLoss().to(device)

gated1_loss = []
gated1_accuracy = []
# optimizer = optim.Adam([conw1g, conw2g, conw3g, conw4g, conw5g], lr=0.001)
# optimizerbn = optim.Adam(netg.parameters(), lr=0.001)
for epoch in range(NUM_EPOCH) :    
    running_loss = 0.0
    learning_rate = lr_list[epoch]
#     learning_rate = 0.05
    optimizer = optim.SGD(conw1n_list, lr=0.0001)
#     optimizerbn = optim.SGD(net.parameters(), lr=learning_rate)    
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        outputs = net(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()

#         optimizerbn.zero_grad()

        # print statistics
        running_loss += loss.item()
        if i % 400 == 399:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / (i+1)))
            print(conw1n_list[0].grad.norm())
            print(conw1n_list[4].grad.norm())
            
        optimizer.step()
#         optimizerbn.step()
        
        optimizer.zero_grad()            
            
    test_acc = test(net)
    gated1_loss.append(running_loss/len(trainloader))
    gated1_accuracy.append(test_acc)

****************** Numerical Gardient update ****************




[1,   400] loss: 2.338
tensor(1.7980e-11, device='cuda:0')
tensor(3.7285e-11, device='cuda:0')
Accuracy of the network on the 10000 test images: 7 %
[2,   400] loss: 2.337
tensor(2.5748e-11, device='cuda:0')
tensor(4.8592e-11, device='cuda:0')
Accuracy of the network on the 10000 test images: 7 %
[3,   400] loss: 2.336
tensor(1.7991e-11, device='cuda:0')
tensor(4.2521e-11, device='cuda:0')
Accuracy of the network on the 10000 test images: 7 %
[4,   400] loss: 2.337
tensor(2.5702e-11, device='cuda:0')
tensor(1.0588e-11, device='cuda:0')
Accuracy of the network on the 10000 test images: 7 %
[5,   400] loss: 2.336
tensor(2.5002e-11, device='cuda:0')
tensor(3.9654e-11, device='cuda:0')
Accuracy of the network on the 10000 test images: 7 %
[6,   400] loss: 2.336
tensor(2.7104e-11, device='cuda:0')
tensor(9.1962e-11, device='cuda:0')
Accuracy of the network on the 10000 test images: 7 %
[7,   400] loss: 2.337
tensor(2.1904e-11, device='cuda:0')
tensor(5.0520e-11, device='cuda:0')
Accuracy of

In [12]:
conw1n_list[0].grad

tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0