In [1]:
import torch
from torch import nn
from torch.autograd.function import Function
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Module
from random import random
import numpy as np
import torch.utils.data
from torchvision import datasets, transforms
import torch.optim as optim
from torch.nn.init import kaiming_normal
from torch.nn.init import xavier_normal

In [16]:
class Quantize(Function):
    
    @staticmethod
    def forward(ctx, input, forward_bits, backward_bits, mode="Stochastic"):
        ctx.backward_bits = backward_bits
        ctx.mode = mode
        return quantize(input, forward_bits, mode)
    
    @staticmethod
    def backward(ctx, dldy):
        out = quantize(dldy, ctx.backward_bits, ctx.mode)
        return dldy, None, None

def quantize(data, bits, mode="Stochastic"):
    """
    Quantzie a Tensor.
    """
    data_bits, precision_bits = bits
    if mode=="Nearest":
        temp = data.clone()
        return temp.round()
    elif mode=="Stochastic":
        temp = data / 2**(-precision_bits)
        def saturate(data_bits, d):
            magnitude_bits = data_bits - 1
            upper = 2**magnitude_bits-1
            lower = -2**magnitude_bits
            return torch.clamp(d, lower, upper)
        temp = (temp+random()).floor()
        temp = saturate(data_bits, temp)
        temp *= 2**(-precision_bits)
        return temp
    else: raise ValueError("Invalid quantization mode")

In [11]:
def kaiming_normal_quantized(tensor, bits, a=0, mode="fan_in"):
    kaiming_normal(tensor, a, mode)
    tensor = quantize(tensor, bits)
def xavier_normal_quantized(tensor, bits):
    xavier_normal(tensor)
    tensor = quantize(tensor, bits)

In [15]:
class Net(Module):
    def __init__(self, bits):
        super(Net, self).__init__()
        self.bits = bits
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                xavier_normal_quantized(m.weight.data, bits)
                xavier_normal(m.weight.data)
    def forward(self, x):
        quantize = Quantize.apply
        Q = lambda x: quantize(x, self.bits, self.bits)
        out = Q(x)
        out = x
        out = F.relu(self.conv1(out))
        out = F.max_pool2d(out,2)
        out = Q(out)
        
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out,2)
        out = Q(out)
        
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return F.log_softmax(out, dim=1)

In [4]:
bath_size = 100
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=bath_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=bath_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [20]:
model = Net((16, 12))

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(model.conv1.weight.grad[0])
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))

def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


for epoch in range(1, 11):
    train(epoch)
    test()


Variable containing:
(0 ,.,.) = 
1.00000e-02 *
   0.9106 -1.4016 -2.4634 -2.9015 -4.4610
  -1.4743 -4.8777 -6.1824 -7.0564 -6.0621
  -2.0619 -4.1446 -5.8531 -6.2602 -3.2471
  -1.1353 -2.5930 -2.1119 -0.7680  0.6779
  -0.8059 -1.6521 -1.0822  0.7824  0.6881
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
1.00000e-02 *
  -4.6509 -5.0046 -3.2611 -0.6318  3.5595
  -2.4335 -2.6877 -2.5309  0.4319  3.3839
  -0.5292 -0.9974 -1.0130 -0.0842  1.5143
  -1.8066 -1.4887 -2.7068 -3.6993 -0.8902
  -1.5677 -1.8491 -2.7894 -2.7369 -0.4047
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
1.00000e-03 *
   3.2116  2.0720  1.9691 -0.1986  0.0143
   3.4894  0.2916 -0.5197  2.5097  1.5529
   2.5205  3.6874  5.4208  4.4934  1.4591
  -1.4871 -1.1585  6.5221  8.3587  0.7424
  -3.7186  3.0363  9.9834  7.7677 -3.2018
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
1.00000e-02 *
  -0.1417  0.5711  0.1066 -0.1108 -0.5642
  -1.0106 -0.1525 -0.0315 -0.6795 -


Test set: Average loss: 0.2998, Accuracy: 8940/10000 (89%)

Variable containing:
(0 ,.,.) = 
1.00000e-02 *
   1.1141  0.0244 -2.2096 -2.3788 -0.1648
   1.0326 -1.2767 -1.6701 -0.4063  2.2081
  -0.2019 -1.9554 -0.8604  2.4016  2.0351
  -2.3406 -3.7140 -2.6225 -1.2702 -0.6758
  -6.8020 -5.8188 -4.9400 -2.0924  0.6023
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
1.00000e-02 *
   0.8790  1.2739  0.8780  0.1790 -0.3878
   0.9377  1.0271  0.8021  0.6070  0.1216
   0.2717  0.5780  1.2556  1.0662  0.6504
   0.0021  1.0441  1.2430  1.0705  0.4098
   0.0626  0.9614  1.4951  0.6564 -0.0208
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
1.00000e-02 *
   1.3174  2.3683  2.6123  0.1211 -3.1353
  -1.2785 -0.1558  1.1260 -2.9184 -2.6676
   0.4183  0.4449 -2.2221 -0.7640  0.6886
   0.9887 -0.5341 -2.4657  1.3313 -0.0530
   1.5900 -1.4034 -0.7208 -1.9747 -1.1616
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
1.00000e-02 *
  -0.6127  1.149


Test set: Average loss: 0.9061, Accuracy: 7092/10000 (71%)

Variable containing:
(0 ,.,.) = 
  0.0425  0.0604  0.1777  0.2720  0.2135
  0.0119  0.1162  0.2108  0.2434  0.1423
  0.0281  0.1140  0.1678  0.1124  0.0367
  0.0742  0.0901  0.0877  0.0289 -0.0109
  0.0759  0.0855  0.0427  0.0200  0.0154
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
 -0.0128 -0.0678 -0.0792 -0.0732 -0.0265
 -0.0036 -0.0825 -0.0939 -0.0606 -0.0161
 -0.0278 -0.1054 -0.1460 -0.0988 -0.0288
 -0.0595 -0.1657 -0.1594 -0.0996 -0.0150
 -0.1048 -0.1668 -0.1721 -0.0506 -0.0064
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
  0.2140  0.2656  0.1991  0.1420  0.1115
  0.2271  0.2626  0.1976  0.1195  0.0495
  0.1787  0.1789  0.1069  0.0364  0.0129
  0.0804  0.0306 -0.0221 -0.0400 -0.0100
 -0.0201 -0.0274 -0.0168  0.0366  0.0772
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
 -0.0165 -0.0259 -0.0452 -0.0248 -0.0255
 -0.0373 -0.0780 -0.0772 -0.0717 -0.0382
 -0.0