In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [2]:
# Inherit from Function
class LinearFunction(torch.autograd.Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
#         ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        ctx.save_for_backward(input, weight, bias, output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias, output = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
#         print(grad_output.shape)
#         print((grad_output * (torch.sigmoid(output)-0.5)*2).shape)
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
#             grad_input = (grad_output * (torch.sigmoid(output)-0.5)*2).mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
#             grad_weight = (grad_output * (torch.sigmoid(output)-0.5)*2).t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

In [3]:
# Inherit from Function
class LinearGated(torch.autograd.Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
#         ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        ctx.save_for_backward(input, weight, bias, output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias, output = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
#         print(grad_output.shape)
#         print((grad_output * (torch.sigmoid(output)-0.5)*2).shape)
#         grad_output = (grad_output * (torch.sigmoid(output)-0.5)*2)
#         (torch.nn.Softmax(dim=1)(torch.sigmoid(x)))
        grad_output = grad_output * torch.nn.Softmax(dim=1)(torch.sigmoid(grad_output))
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
#             grad_input = (grad_output * torch.nn.Softmax(dim=1)(torch.sigmoid(grad_output))).mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
#             grad_weight = (grad_output * (torch.sigmoid(output)-0.5)*2).t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

In [4]:
from torch.autograd import gradcheck
linear = LinearFunction.apply
linearGated = LinearGated.apply
# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
# input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
# test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
# print(test)

# input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
# test = gradcheck(linearGated, input, eps=1e-2, atol=1e-2)
# print(test)

In [5]:
dtype = torch.float
device = torch.device("cpu")

In [6]:
# # N is batch size; D_in is input dimension;
# # H is hidden dimension; D_out is output dimension.
# N, D_in, H, D_out = 4, 1000, 100, 10

# # # Create random Tensors to hold input and outputs.
# # x = torch.randn(N, D_in, device=device, dtype=dtype)
# # y = torch.randn(N, D_out, device=device, dtype=dtype)

# # Create random Tensors for weights.
# w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
# w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

In [7]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 8

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified


In [8]:
N, D_in, H, D_out = batch_size, 32*32*3, 100, 10
# Create random Tensors for weights.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)
# w3 = w1.clone()
# w4 = w2.clone()
# print(w3)
w3 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w4 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

In [9]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.dense1 = LinearFunction.apply
        self.dense2 = LinearFunction.apply
#         self.w1 = w1
#         self.w2 = w2

    def forward(self, x, w1, w2):
#         x = self.pool(F.relu(self.convZ(x)))
#         x = x.view(-1, 1 * 1 * 1)
#         x = self.fc1(x)
        x = self.dense2(F.relu(self.dense1(x, w1.t())), w2.t())
        return x
    
class NetG(nn.Module):
    def __init__(self):
        super(NetG, self).__init__()
        self.dense1 = LinearGated.apply
        self.dense2 = LinearGated.apply
#         self.w1 = w1
#         self.w2 = w2

    def forward(self, x, w3, w4):
#         x = self.pool(F.relu(self.convZ(x)))
#         x = x.view(-1, 1 * 1 * 1)
#         x = self.fc1(x)
        x = (self.dense2(F.relu(self.dense1(x, w3.t())), w4.t()))
        return x    

In [10]:
# x = torch.randn(4,3)
# print(x)
# print(torch.sigmoid(x))
# print(torch.nn.Softmax(dim=1)(torch.sigmoid(x)))

In [11]:
net = Net()
netg = NetG()

# Define a Loss function and optimizer
criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [12]:
def test (model, w1, w2) :
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            inputs = torch.flatten(inputs,1)  

            # zero the parameter gradients
        #     optimizer.zero_grad()   
            # forward + backward + optimize
            outputs = model(inputs, w1, w2)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))

In [13]:
learning_rate = 0.0005
for epoch in range(40) :    
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = torch.flatten(inputs,1)  

        # zero the parameter gradients
    #     optimizer.zero_grad()   
        # forward + backward + optimize
        outputs = netg(inputs, w3, w4)
    #     print(outputs.shape)
        loss = criterion(outputs, labels)
#         print(loss)
        loss.backward()
    #     optimizer.step()   
#         print(w3.grad)

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 100:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0 
        with torch.no_grad():
            # Update weights using gradient descent
            w3 -= learning_rate * w3.grad
            w4 -= learning_rate * w4.grad

            # Manually zero the gradients after running the backward pass
            w3.grad.zero_()
            w4.grad.zero_()   
    test(netg, w3, w4)
            
# correct = 0
# total = 0
# # since we're not training, we don't need to calculate the gradients for our outputs
# with torch.no_grad():
#     for data in testloader:
#         inputs, labels = data
#         inputs = torch.flatten(inputs,1)  

#         # zero the parameter gradients
#     #     optimizer.zero_grad()   
#         # forward + backward + optimize
#         outputs = netg(inputs, w3, w4)
#         # the class with the highest energy is what we choose as prediction
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

# print('Accuracy of the network on the 10000 test images: %d %%' % (
#     100 * correct / total))

[1,   101] loss: 14.467
[1,  2101] loss: 219.724
[1,  4101] loss: 170.735
[1,  6101] loss: 155.532
Accuracy of the network on the 10000 test images: 14 %
[2,   101] loss: 7.353
[2,  2101] loss: 140.239
[2,  4101] loss: 135.327
[2,  6101] loss: 128.152
Accuracy of the network on the 10000 test images: 16 %
[3,   101] loss: 6.011
[3,  2101] loss: 120.839
[3,  4101] loss: 116.535
[3,  6101] loss: 109.979
Accuracy of the network on the 10000 test images: 17 %
[4,   101] loss: 5.228
[4,  2101] loss: 106.396
[4,  4101] loss: 102.468
[4,  6101] loss: 98.861
Accuracy of the network on the 10000 test images: 18 %
[5,   101] loss: 4.679
[5,  2101] loss: 94.413
[5,  4101] loss: 91.893
[5,  6101] loss: 90.437
Accuracy of the network on the 10000 test images: 19 %
[6,   101] loss: 4.689
[6,  2101] loss: 85.829
[6,  4101] loss: 83.042
[6,  6101] loss: 81.169
Accuracy of the network on the 10000 test images: 20 %
[7,   101] loss: 4.053
[7,  2101] loss: 77.795
[7,  4101] loss: 76.565
[7,  6101] loss: 

In [14]:
learning_rate = 0.0002
for epoch in range(20) :    
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = torch.flatten(inputs,1)  

        # zero the parameter gradients
    #     optimizer.zero_grad()   
        # forward + backward + optimize
        outputs = net(inputs, w1, w2)
    #     print(outputs.shape)
        loss = criterion(outputs, labels)
#         print(loss)
        loss.backward()
    #     optimizer.step()   

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 100:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0 
        with torch.no_grad():
            # Update weights using gradient descent
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after running the backward pass
            w1.grad.zero_()
            w2.grad.zero_()    
    test(net, w1, w2)
            
# correct = 0
# total = 0
# # since we're not training, we don't need to calculate the gradients for our outputs
# with torch.no_grad():
#     for data in testloader:
#         inputs, labels = data
#         inputs = torch.flatten(inputs,1)  

#         # zero the parameter gradients
#     #     optimizer.zero_grad()   
#         # forward + backward + optimize
#         outputs = net(inputs, w1, w2)
#         # the class with the highest energy is what we choose as prediction
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

# print('Accuracy of the network on the 10000 test images: %d %%' % (
#     100 * correct / total))            

[1,   101] loss: 13.337
[1,  2101] loss: 156.951
[1,  4101] loss: 116.333
[1,  6101] loss: 99.324
Accuracy of the network on the 10000 test images: 22 %
[2,   101] loss: 4.338
[2,  2101] loss: 83.050
[2,  4101] loss: 75.765
[2,  6101] loss: 69.156
Accuracy of the network on the 10000 test images: 24 %
[3,   101] loss: 3.179
[3,  2101] loss: 60.795
[3,  4101] loss: 55.546
[3,  6101] loss: 50.304
Accuracy of the network on the 10000 test images: 26 %
[4,   101] loss: 2.451
[4,  2101] loss: 45.347
[4,  4101] loss: 41.644
[4,  6101] loss: 38.821
Accuracy of the network on the 10000 test images: 27 %
[5,   101] loss: 1.773
[5,  2101] loss: 34.686
[5,  4101] loss: 31.533
[5,  6101] loss: 29.228
Accuracy of the network on the 10000 test images: 27 %
[6,   101] loss: 1.345
[6,  2101] loss: 26.164
[6,  4101] loss: 23.363
[6,  6101] loss: 21.471
Accuracy of the network on the 10000 test images: 28 %
[7,   101] loss: 1.101
[7,  2101] loss: 18.897
[7,  4101] loss: 16.932
[7,  6101] loss: 15.060
Ac