In [1]:
import sys
# Put the parent directory on the import path (once) so that the
# project-local `source` package can be imported below.
if sys.path.count('..') == 0:
    sys.path.append('..')

In [2]:
import torch
from torch import nn
from torchvision.datasets import mnist
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm

from source.models import LeNet, LeNetQuant

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Restore the pretrained float LeNet weights.
# map_location='cpu' lets the checkpoint load on CPU-only machines even if it
# was saved from a GPU; the notebook runs the model on the CPU anyway.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model = LeNet()
state_dict = torch.load('../models/lenet_mnist_0.989.sd', map_location='cpu')
model.load_state_dict(state_dict)
# Switch to inference mode; as the last expression this also displays the
# module structure.
model.eval()

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (relu3): ReLU()
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (relu4): ReLU()
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu5): ReLU()
)

In [4]:
# GPU placement is intentionally left disabled (kept for reference only);
# the next cell moves the model to the CPU instead.
# use_gpu = torch.cuda.is_available()
# if use_gpu:
#     model = model.cuda()
#     print ('USE GPU')
# else:
#     print ('USE CPU')

In [5]:
# Keep the model on the CPU for every evaluation / quantization step below.
model = model.to('cpu')

In [6]:
# MNIST train / test splits, converted to tensors by ToTensor().
# download=True fetches the files only when they are missing from `root`, so
# the notebook also runs from a fresh checkout; existing data is reused as-is.
batch_size = 256
train_dataset = mnist.MNIST(
    root='../data/MNIST/train', train=True, transform=ToTensor(), download=True
)
test_dataset = mnist.MNIST(
    root='../data/MNIST/test', train=False, transform=ToTensor(), download=True
)
# shuffle=False (the DataLoader default) is spelled out on purpose: the
# calibration step later consumes the *first* batches of train_loader, so a
# fixed order keeps that sample reproducible.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
def evaluate_accuracy(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Args:
        model: module called on a float input batch; the predicted class is
            the argmax over the last dimension of its output.
        loader: iterable yielding ``(inputs, labels)`` tensor pairs.
        device: optional device to move each batch to before the forward
            pass. ``None`` (default) leaves the batches where the loader
            produced them.

    Returns:
        Fraction of correctly classified samples as a Python float.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # no_grad already prevents graph construction, so no .detach() is needed.
    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            if device is not None:
                inputs = inputs.to(device)
                labels = labels.to(device)
            predictions = torch.argmax(model(inputs.float()), dim=-1)
            hits = predictions == labels
            correct += hits.sum().item()
            total += hits.shape[0]
    if total == 0:
        raise ValueError('loader yielded no samples')
    return correct / total


# Baseline: accuracy of the float model on the MNIST test split.
acc = evaluate_accuracy(model, test_loader)
print('accuracy: {:.4f}'.format(acc))

100%|██████████| 40/40 [00:00<00:00, 43.28it/s]

accuracy: 0.9889





# Eager Mode Quantization

## Dynamic Quantization

In [16]:
# Dynamic int8 quantization of the float model.
# Only nn.Linear is requested: listing nn.Conv2d (or MaxPool2d / ReLU) has no
# effect with the default dynamic mapping -- the conv1 dtype check in the next
# cell still reports torch.float32 -- so the set names just the layer type
# that is actually converted.
# quantize_dynamic returns a converted copy (inplace defaults to False);
# `model` is rebound to it because the following cells evaluate `model`.
model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8,
)

In [17]:
# Check whether dynamic quantization touched the conv layer: the recorded
# output is torch.float32, i.e. conv1's weights were left unquantized.
model.conv1.weight.dtype

torch.float32

In [18]:
# Accuracy of the dynamically quantized model on the MNIST test split.
model.eval()
use_gpu = False  # set to True to push every batch to CUDA first
all_correct_num, all_sample_num = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        if use_gpu:
            images, labels = images.cuda(), labels.cuda()
        scores = model(images.float()).detach()
        # Boolean mask: True where the arg-max class equals the label.
        hits = torch.argmax(scores, axis=-1) == labels
        all_correct_num += np.sum(hits.cpu().numpy(), axis=-1)
        all_sample_num += hits.shape[0]
acc = all_correct_num / all_sample_num
print('accuracy: {:.4f}'.format(acc))

100%|██████████| 40/40 [00:00<00:00, 45.49it/s]

accuracy: 0.9887





## Static Quantization

In [18]:
# Start the static-quantization flow from a fresh LeNetQuant that loads the
# same pretrained float checkpoint.
# map_location='cpu' lets the checkpoint load on CPU-only machines even if it
# was saved from a GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model = LeNetQuant()
state_dict = torch.load('../models/lenet_mnist_0.989.sd', map_location='cpu')
model.load_state_dict(state_dict)
# Inference mode; as the last expression this also displays the module
# structure.
model.eval()

LeNetQuant(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (relu3): ReLU()
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (relu4): ReLU()
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu5): ReLU()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

In [19]:
# Attach the default quantization config for the 'fbgemm' backend to the model.
default_qconfig = torch.quantization.get_default_qconfig('fbgemm')
model.qconfig = default_qconfig

In [20]:
# Fuse every conv / linear layer with the ReLU that directly follows it, so
# each pair is treated as one module by the later prepare / convert steps.
fusion_groups = [
    ['conv1', 'relu1'],
    ['conv2', 'relu2'],
    ['fc1', 'relu3'],
    ['fc2', 'relu4'],
    ['fc3', 'relu5'],
]
model = torch.quantization.fuse_modules(model, fusion_groups)

In [21]:
# Insert observers according to model.qconfig; inplace=False (the default)
# returns a prepared copy, which replaces `model`.
model = torch.quantization.prepare(model, inplace=False)

In [12]:
# Inspect conv1 after fusing and preparing: the recorded output shows a
# ConvReLU2d (Conv2d + ReLU) with a HistogramObserver attached.
model.conv1

ConvReLU2d(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (activation_post_process): HistogramObserver()
)

In [22]:
# Calibration: run a few training batches through the prepared model so its
# observers can record activation statistics before convert().
# The sample is the first NUM_CALIBRATION_BATCHES batches of train_loader
# (2 batches = 2 * batch_size images), not a percentage of the dataset.
# NOTE(review): 2 batches is well under 5% of the MNIST training split; raise
# this constant if a larger calibration sample is wanted.
NUM_CALIBRATION_BATCHES = 2

model.eval()
# Never ask for more batches than the loader can provide.
calibration_steps = min(NUM_CALIBRATION_BATCHES, len(train_loader))
calibration_batches = iter(train_loader)
with torch.no_grad():
    # Iterating over range() gives tqdm a known total, so the bar reports
    # every processed batch.
    for _step in tqdm(range(calibration_steps)):
        train_x, _labels = next(calibration_batches)
        model(train_x.float())

1it [00:00, 16.52it/s]


In [23]:
# Convert the calibrated model to its quantized form; inplace=False (the
# default) returns a converted copy, which replaces `model`.
model = torch.quantization.convert(model, inplace=False)

In [24]:
# Inspect conv1 after convert(): the recorded output shows a
# QuantizedConvReLU2d carrying an output scale and zero_point.
model.conv1

QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.08663195371627808, zero_point=0)

In [25]:
# On the converted module `weight` is called as a method; the recorded dtype
# is torch.qint8, confirming the conv weights are now quantized.
model.conv1.weight().dtype

torch.qint8

In [26]:
# Accuracy of the statically quantized (converted) model on the MNIST test split.
model.eval()
use_gpu = False  # set to True to push every batch to CUDA first
all_correct_num, all_sample_num = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        if use_gpu:
            images, labels = images.cuda(), labels.cuda()
        scores = model(images.float()).detach()
        # Boolean mask: True where the arg-max class equals the label.
        hits = torch.argmax(scores, axis=-1) == labels
        all_correct_num += np.sum(hits.cpu().numpy(), axis=-1)
        all_sample_num += hits.shape[0]
acc = all_correct_num / all_sample_num
print('accuracy: {:.4f}'.format(acc))

100%|██████████| 40/40 [00:01<00:00, 32.40it/s]

accuracy: 0.9885



