In [1]:
# https://discuss.pytorch.org/t/call-backward-on-function-inside-a-backpropagation-step/3793
# https://discuss.pytorch.org/t/implementing-a-custom-convolution-using-conv2d-input-and-conv2d-weight/18556
# https://discuss.pytorch.org/t/implementing-a-custom-convolution-using-conv2d-input-and-conv2d-weight/18556/21

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import numpy as np
import sys
sys.path.append('./')
from gated import Conv2dFunctionG, Conv2dFunction, CustomBatchNormManualFunction, Net, NetA, NetG

In [3]:
# from torch.autograd import gradcheck
# conv = Conv2dFunction.apply
# gradcheck takes a tuple of tensors as input, checks whether the gradients
# evaluated at these tensors are close enough to their numerical
# approximations, and returns True if they all satisfy this condition.
# input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
# test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
# print(test)


In [4]:
# Run on the GPU when one is available and fall back to the CPU otherwise, so
# the notebook still executes on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float

# Convert PIL images to tensors in [0, 1], then shift/scale every RGB channel
# with mean 0.5 and std 0.5 (i.e. to the range [-1, 1]).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 64

# CIFAR-10 training split, reshuffled on every pass.
trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=8)

# CIFAR-10 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=8)

# Human-readable class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Convolution kernels for the baseline ("normal") network.
# Presumably laid out as (C_out, C_in, kH, kW) — confirm against gated.Net.
conw1 = torch.randn(8, 3, 5, 5, device=device, dtype=dtype, requires_grad=True)
conw2 = torch.randn(32, 8, 3, 3, device=device, dtype=dtype, requires_grad=True)
conw3 = torch.randn(128, 32, 3, 3, device=device, dtype=dtype, requires_grad=True)
conw4 = torch.randn(128, 128, 3, 3, device=device, dtype=dtype, requires_grad=True)
conw5 = torch.randn(10, 128, 1, 1, device=device, dtype=dtype, requires_grad=True)
conw_list = [conw1, conw2, conw3, conw4, conw5]

# Xavier-uniform initialisation, applied in place; this overwrites the randn
# draw above, and the tensors stay leaf tensors that require grad.
for _kernel in conw_list:
    torch.nn.init.xavier_uniform_(_kernel, gain=1.0)

# Batch-norm weight (ones) and bias (zeros) tensors for the custom backward,
# normal case.
bn1w = torch.ones(8, device=device, dtype=dtype, requires_grad=True)
bn1b = torch.zeros(8, device=device, dtype=dtype, requires_grad=True)
bn2w = torch.ones(32, device=device, dtype=dtype, requires_grad=True)
bn2b = torch.zeros(32, device=device, dtype=dtype, requires_grad=True)
bn3w = torch.ones(128, device=device, dtype=dtype, requires_grad=True)
bn3b = torch.zeros(128, device=device, dtype=dtype, requires_grad=True)
bn4w = torch.ones(128, device=device, dtype=dtype, requires_grad=True)
bn4b = torch.zeros(128, device=device, dtype=dtype, requires_grad=True)
bn_list = [bn1w, bn1b, bn2w, bn2b, bn3w, bn3b, bn4w, bn4b]


def _leaf_copy(weight):
    """Return an independent, trainable leaf tensor holding a copy of ``weight``.

    This is the recommended replacement for ``torch.tensor(weight, ...,
    requires_grad=True)``, which emits a UserWarning when handed an existing
    tensor. The copy shares neither storage nor autograd history with the
    source, so every experiment starts from identical values but trains
    independently.
    """
    return weight.detach().clone().to(device=device, dtype=dtype).requires_grad_(True)


# Kernels for the gated run: same starting values as the baseline.
conw1g = _leaf_copy(conw1)
conw2g = _leaf_copy(conw2)
conw3g = _leaf_copy(conw3)
conw4g = _leaf_copy(conw4)
conw5g = _leaf_copy(conw5)

# Kernels for the Adam run: same starting values as the baseline.
conw1a = _leaf_copy(conw1)
conw2a = _leaf_copy(conw2)
conw3a = _leaf_copy(conw3)
conw4a = _leaf_copy(conw4)
conw5a = _leaf_copy(conw5)

# Batch-norm weight (ones) and bias (zeros) tensors for the custom backward,
# gated case.
bn1wg = torch.ones(8, device=device, dtype=dtype, requires_grad=True)
bn1bg = torch.zeros(8, device=device, dtype=dtype, requires_grad=True)
bn2wg = torch.ones(32, device=device, dtype=dtype, requires_grad=True)
bn2bg = torch.zeros(32, device=device, dtype=dtype, requires_grad=True)
bn3wg = torch.ones(128, device=device, dtype=dtype, requires_grad=True)
bn3bg = torch.zeros(128, device=device, dtype=dtype, requires_grad=True)
bn4wg = torch.ones(128, device=device, dtype=dtype, requires_grad=True)
bn4bg = torch.zeros(128, device=device, dtype=dtype, requires_grad=True)


# print(conw1[0][0])
# print(torch.nn.init.xavier_uniform_(conw1, gain=1.0)[0][0])
# print(conw2)



In [6]:
# Sanity check: the gated weight must be a distinct tensor object, not an
# alias of the baseline weight (expected result: False).
conw1 is conw1g

False

In [7]:
# print(conw1)

In [8]:
# print(conw1g)

In [9]:
# Build the baseline network from the shared conv / batch-norm tensor lists
# and move it to `device`.
net = Net(conw_list, bn_list).to(device)
# NOTE(review): the gated training cell further down uses `netg`, whose
# construction is commented out here — restore it (and `neta` if the Adam run
# is re-enabled) before running those cells on a fresh kernel.
# netg = NetG(conw_list, bn_list).to(device)
# neta = NetA(conw_list, bn_list).to(device)
# Cross-entropy loss used for the smoke test below.
criterion = nn.CrossEntropyLoss().to(device)

In [10]:
# for k in neta.state_dict() :
#     if 'bn1' in k :
#         print(k)
#         print(neta.state_dict()[k])
        
# for k in netg.state_dict() :
#     if 'bn1' in k :
#         print(k)
#         print(netg.state_dict()[k])   
# Show the shape of every tensor that `net` exposes through parameters().
for param in net.parameters():
    print(param.shape)
#     if 'bn2' in k :
#         print(k)
#         print(net.state_dict()[k])              

torch.Size([10, 128])
torch.Size([10])


In [11]:
# for p in neta.parameters() :
#     print(p[0][0][0])

In [12]:
# for param in neta.parameters() :
#     print(param.shape)

# # parameter for adam should be same with other model
# neta_dict = neta.state_dict()
# for p in neta_dict :
#     if 'conv1' in p :
#         neta_dict[p] = conw1a
#     elif 'conv2' in p :
#         neta_dict[p] = conw2a
#     elif 'conv3' in p :
#         neta_dict[p] = conw3a
#     elif 'conv4' in p :
#         neta_dict[p] = conw4a
#     elif 'conv5' in p :
#         neta_dict[p] = conw5a        
# #     print(p)
# #     print(neta.state_dict()[p].shape)
# #     print()
# neta.load_state_dict(neta_dict)

In [13]:
# for p in neta.parameters() :
#     print(p[0][0][0])
# print(conw1a[0][0][0])    

In [14]:
# Pull a single training batch to confirm the tensor shape the loader yields.
# The built-in next() is used because the iterator's Python-2 style .next()
# method is not available in current PyTorch releases.
image, label = next(iter(trainloader))
print(image.shape)

torch.Size([64, 3, 32, 32])


In [15]:
# Scratch tensors for a broadcasting experiment.
# NOTE(review): .view() only reinterprets the (64, 128, 4, 4) buffer as
# (64, 4, 4, 128); it does not move the channel axis. If a channels-last
# layout was intended, .permute(0, 2, 3, 1) would be needed — confirm.
a = torch.randn(64, 128, 4, 4)
a = a.view(64, 4, 4, 128)
b = torch.randn(128)
# a * b

In [16]:
# Smoke test: push one training batch through the baseline network,
# backpropagate the loss and inspect two of the batch-norm gradients.
# next(iter(...)) replaces the .next() method that current PyTorch no longer
# provides on DataLoader iterators.
image, labels = next(iter(trainloader))
# The network's output is used as returned; no extra .to(device) is applied.
outputs = net(image.to(device))
print(outputs.shape)
loss = criterion(outputs, labels.to(device))
loss.backward()
print(net.bn1w.grad)
print(net.bn4w.grad[0:5])

RuntimeError: The size of tensor a (8) must match the size of tensor b (32) at non-singleton dimension 3

In [None]:
# Apply one manual update through the network's own update_weight() with a
# step size of 0.05, then print bn1w and the first five entries of bn4w so the
# effect of the step can be inspected.
# net.print_weight()
net.update_weight(0.05)
print(net.bn1w)
print(net.bn4w[0:5])

In [None]:
# conw5.grad

In [None]:
def test (model, w1=None, w2=None, w3=None, w4=None, w5=None) :
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            if w1 == None :
                outputs = model(inputs.to(device))
            else :
                outputs = model(inputs.to(device))
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))
    return (100 * correct / total)

In [None]:
# Number of passes over the training set. The training cells index their
# per-epoch lr_list with the epoch number, so each schedule must provide at
# least this many entries (they currently hold exactly 15).
NUM_EPOCH = 15

In [None]:
# Baseline run: train `net` with its own update_weight() step, using a
# hand-written step-decay schedule with one learning rate per epoch.
print('******************normal case****************')
lr_list = [0.1, 0.1, 0.05, 0.05, 0.05,
           0.02, 0.02, 0.02, 0.01, 0.01,
           0.01, 0.005, 0.005, 0.002, 0.001]
# Constant-rate alternative: lr_list = [0.05] * NUM_EPOCH
criterion = nn.CrossEntropyLoss().to(device)
normal_loss, normal_accuracy = [], []
for epoch in range(NUM_EPOCH):
    learning_rate = lr_list[epoch]
    running_loss = 0.0
    for step, (inputs, labels) in enumerate(trainloader, start=1):
        outputs = net(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()

        # Report the running mean loss every 200 mini-batches.
        running_loss += loss.item()
        if step % 200 == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, step, running_loss / step))

        # Manual parameter update implemented by the network itself
        # (presumably also responsible for clearing gradients — confirm).
        net.update_weight(learning_rate)

    # Per-epoch bookkeeping: mean training loss and test accuracy.
    test_acc = test(net)
    normal_loss.append(running_loss / len(trainloader))
    normal_accuracy.append(test_acc)

In [None]:
# print('******************normal case with adam****************')
# learning_rate = 0.001
# criterion = nn.CrossEntropyLoss().to(device)
# optimizer = optim.Adam(neta.parameters(), lr=learning_rate)
# adam_loss = []
# adam_accuracy = []
# for epoch in range(NUM_EPOCH) :    
#     running_loss = 0.0
#     neta.train()
#     for i, data in enumerate(trainloader, 0):
#         inputs, labels = data

#         outputs = neta(inputs.to(device))
#         loss = criterion(outputs, labels.to(device))
#         loss.backward()  
#         optimizer.step()
#         optimizer.zero_grad()

#         # print statistics
#         running_loss += loss.item()
#         if i % 200 == 199:    # print every 2000 mini-batches
#             print('[%d, %5d] loss: %.3f' %
#                   (epoch + 1, i + 1, running_loss / (i+1)))
    
#     neta.eval()
#     test_acc = test(neta)
#     adam_loss.append(running_loss/len(trainloader))
#     adam_accuracy.append(test_acc)

In [None]:
# Gated run: train `netg` with manual SGD on the externally held conv
# kernels, using the same per-epoch step-decay schedule as the baseline.
# NOTE(review): `netg` must exist before this cell runs; its construction is
# currently commented out in the model-setup cell.
lr_list = [0.1, 0.1, 0.05, 0.05, 0.05, 0.02, 0.02,
           0.02, 0.01, 0.01, 0.01, 0.005,
           0.005, 0.002, 0.001]
print('******************grad gated****************')
# lr_list = [0.05] * NUM_EPOCH
criterion = nn.CrossEntropyLoss().to(device)
# The Adam optimizer previously built here was never stepped and was
# constructed from `neta`, which is not defined — it only raised a NameError,
# so it has been removed.
gated_weights = [conw1g, conw2g, conw3g, conw4g, conw5g]
gated_loss = []
gated_accuracy = []
for epoch in range(NUM_EPOCH):
    running_loss = 0.0
    learning_rate = lr_list[epoch]
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        # The gated network receives its conv kernels as call arguments.
        outputs = netg(inputs.to(device), conw1g, conw2g, conw3g, conw4g, conw5g)
        loss = criterion(outputs, labels.to(device))
        loss.backward()

        # Report the running mean loss every 200 mini-batches.
        running_loss += loss.item()
        if i % 200 == 199:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / (i+1)))

        with torch.no_grad():
            for weight in gated_weights:
                # Plain gradient-descent step, performed in place so the
                # conw*g names keep pointing at the updated tensors.
                weight -= learning_rate * weight.grad
                # Clear the accumulated gradient before the next backward.
                weight.grad.zero_()

    # Per-epoch bookkeeping: mean training loss and test accuracy.
    test_acc = test(netg, conw1g, conw2g, conw3g, conw4g, conw5g)
    gated_loss.append(running_loss/len(trainloader))
    gated_accuracy.append(test_acc)

In [None]:
# Print every entry of netg's state_dict whose key mentions 'bn1'.
for name, value in netg.state_dict().items():
    if 'bn1' in name:
        print(name)
        print(value)

In [None]:
# Print every entry of net's state_dict whose key mentions 'bn1'.
for name, value in net.state_dict().items():
    if 'bn1' in name:
        print(name)
        print(value)

In [None]:
# normal_loss
# normal_accuracy

In [None]:
# gated_loss
# gated_accuracy

In [None]:
# adam_loss
# adam_accuracy

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

X = np.arange(0,15,1)
plt.plot(X, adam_loss,  color='red')
plt.plot(X, gated_loss, color='blue')
plt.plot(X, normal_loss, color='green')
plt.legend(['adam', 'gated', 'normal'])
plt.show()


In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

X = np.arange(0,15,1)
plt.plot(X, adam_accuracy,  color='red')
plt.plot(X, gated_accuracy, color='blue')
plt.plot(X, normal_accuracy, color='green')
plt.legend(['adam', 'gated', 'normal'])
plt.show()

In [None]:
# Scratch check: mean of a random (2, 4, 3, 3) tensor over every axis except
# dim 1, computed in a single reduction (result has 4 entries).
d = torch.randn(size=(2, 4, 3, 3))
d.mean(dim=(0, 2, 3))

In [None]:
# Same reduction done one axis at a time: (2, 4, 3, 3) -> (4, 3, 3) -> (4, 3)
# -> (4,), i.e. one mean per index of the original dim 1.
d.mean(dim=0).mean(dim=2).mean(dim=1)