In [1]:
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import Module
import numpy as np
from torch.nn._functions.quantize import quantize
from torch.nn import functional as F


import torch.utils.data
import torchvision
from torchvision import datasets, transforms
import torch.optim as optim

from torch.nn.init import kaiming_normal
from torch.nn.init import xavier_normal

In [2]:
A = torch.rand(4, 4)
B = quantize(A, (16, 14))
print(B)


 0.8856  0.2258  0.4086  0.2658
 0.7846  0.5686  0.9352  0.0784
 0.0083  0.1743  0.1168  0.3219
 0.0137  0.0757  0.7559  0.4158
[torch.FloatTensor of size (4,4)]



In [3]:
def kaiming_normal_quantized(tensor, bits, a=0, mode="fan_in"):
    kaiming_normal(tensor, a, mode)
    tensor = quantize(tensor, bits)
def xavier_normal_quantized(tensor, bits):
    xavier_normal(tensor)
    tensor = quantize(tensor, bits)

In [4]:
class Net(Module):
    def __init__(self, forward_bits, backward_bits):
        super(Net, self).__init__()
        self.Q = nn.Quantize(forward_bits, backward_bits)
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                kaiming_normal_quantized(m.weight.data, bits)
    def forward(self, x):
        out = self.Q(x)
        out = F.relu(self.conv1(out))
        out = self.Q(out)
        out = F.max_pool2d(out,2)
        out = F.relu(self.conv2(out))
        out = self.Q(out)
        out = F.max_pool2d(out,2)
        
        out = out.view(-1, 400)
        out = F.relu(self.fc1(out))
        out = self.Q(out)
        out = F.relu(self.fc2(out))
        out = self.Q(out)
        out = self.fc3(out)
        return F.log_softmax(out, dim=1)

In [5]:
bath_size = 100
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=bath_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=bath_size, shuffle=True)

In [6]:
bits = (16, 14)
model = Net(bits, bits)

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(model.fc2.weight.grad[1:10, 1:10])
            print(model.fc2.weight[1:10, 1:10])
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))

def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


for epoch in range(1, 11):
    train(epoch)
    test()



 0.0000  0.0000  0.0035  0.0050  0.0063  0.0002  0.0000  0.0004  0.0000
 0.0000  0.0011  0.1079  0.0904  0.0897  0.0042  0.0000  0.0047  0.0000
 0.0000  0.0005  0.0182  0.0068  0.0226 -0.0001  0.0000  0.0011  0.0000
 0.0000 -0.0001  0.0008 -0.0104  0.0005  0.0000  0.0000 -0.0000  0.0000
 0.0000  0.0025  0.1868  0.1359  0.2204  0.0038  0.0000  0.0093  0.0000
 0.0000 -0.0004  0.0005  0.0014 -0.0077 -0.0007  0.0000 -0.0021  0.0000
 0.0000  0.0000 -0.0005 -0.0006 -0.0011  0.0000  0.0000  0.0000  0.0000
 0.0000  0.0000 -0.0185 -0.0165 -0.0233 -0.0035  0.0000  0.0000  0.0000
 0.0000  0.0000  0.0231  0.0101  0.0150  0.0000  0.0000  0.0008  0.0000
[torch.FloatTensor of size (9,9)]


 0.0040 -0.0110 -0.0364  0.1867  0.1234  0.0281 -0.1563 -0.1068  0.1807
-0.1601  0.0703  0.1003  0.1594  0.0251 -0.1717  0.1340 -0.3247  0.2029
 0.1342 -0.1126  0.0196 -0.1970  0.2420 -0.1135  0.1025 -0.1828 -0.0175
-0.1981 -0.0149 -0.0186 -0.0402  0.0784 -0.1909  0.1880 -0.2523 -0.1320
 0.1785  0.0371  0.0016 -0.




1.00000e-02 *
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.0891  0.0000 -0.2856  0.2823 -0.0312 -0.0082  0.1570 -1.2528 -1.1107
 -0.1513 -0.1164 -2.2051 -0.3224 -0.0318 -0.0388 -0.2013  0.0217 -0.0081
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0176  0.0000  0.0511  0.0583 -0.0129  0.0000  0.0147 -0.1099 -0.3253
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size (9,9)]


-0.2651 -0.0183 -0.0241  0.1803  0.1182  0.0124 -0.1461 -0.1104  0.0658
-0.1526  0.0690 -0.1052  0.1180 -0.0556 -0.1710  0.0820 -0.3344  0.1720
 0.1073 -0.2200 -0.0132 -0.1443  0.2224 -0.1084  0.3719 -0.2742  0.2012
-0.3493 -0.0008  0.2077  0.0126  0.0786 -0.1792  0.0563 -0.3985 -0.2167
 0.1




Test set: Average loss: 1.6851, Accuracy: 3743/10000 (37%)


1.00000e-03 *
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000 -7.4461  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000 -0.2441  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size (9,9)]


-0.2652 -0.0185 -0.0118  0.1807  0.0737  0.0124 -0.2971 -0.2149  0.0567
-0.1526  0.0690 -0.1050  0.1136 -0.0865 -0.1710  0.0480 -0.3382  0.1617
 0.1053 -0.2337  0.2269 -0.1343  0.0813 -0.1047  0.5428 -0.2364  0.2239
-0.3531  0.0234


Test set: Average loss: 1.5438, Accuracy: 4086/10000 (40%)


1.00000e-02 *
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0244  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  2.8319  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.6958  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size (9,9)]


-0.2652 -0.0185 -0.0118  0.1807  0.0737  0.0124 -0.2971 -0.2149  0.0567
-0.1526  0.0690 -0.1050  0.1136 -0.0865 -0.1710  0.0480 -0.3382  0.1617
 0.1053 -0.2337  0.0346 -0.1343  0.0813 -0.1047  0.5428 -0.2364  0.2239
-0.3531  0.0234


Test set: Average loss: 1.6018, Accuracy: 3756/10000 (37%)


1.00000e-03 *
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000 -3.2958  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000 -0.3662  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size (9,9)]


-0.2652 -0.0185 -0.0118  0.1807  0.0737  0.0124 -0.2971 -0.2149  0.0567
-0.1526  0.0690 -0.1050  0.1136 -0.0865 -0.1710  0.0480 -0.3382  0.1617
 0.1053 -0.2337  0.0996 -0.1343  0.0813 -0.1047  0.5428 -0.2364  0.2239
-0.3531  0.0234

KeyboardInterrupt: 