In [1]:
import datetime
import os
import sys
import time

import torch
import torch.utils.data
from torch import nn

from tqdm import tqdm

import torchvision
from torchvision import transforms

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

from absl import logging
logging.set_verbosity(logging.FATAL)  # Disable logging as they are too noisy in notebook



In [2]:
# For simplicity, import train and eval functions from the train script from torchvision instead of copything them here
# Download torchvision from https://github.com/pytorch/vision
sys.path.append("/raid/skyw/models/torchvision/references/classification/")
from train import evaluate, train_one_epoch, load_data

## Set default QuantDescriptor to use histogram based calibration for activation

In [3]:
quant_desc_input = QuantDescriptor(calib_method='histogram')
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

## Initialize quantized modules

In [4]:
from pytorch_quantization import quant_modules
quant_modules.initialize()

## Create model with pretrained weight

In [5]:
model = torchvision.models.resnet50(pretrained=True, progress=False)
model.cuda()

ResNet(
  (conv1): QuantConv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator quant)
  )
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): QuantConv2d(
        64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
        (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator quant)
        (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator quant)
      )
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv

## Create data loader

In [6]:
data_path = "/raid/data/imagenet/imagenet_pytorch"
batch_size = 512

traindir = os.path.join(data_path, 'train')
valdir = os.path.join(data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(traindir, valdir, False, False)

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size,
    sampler=train_sampler, num_workers=4, pin_memory=True)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=batch_size,
    sampler=test_sampler, num_workers=4, pin_memory=True)


Loading data
Loading training data
Took 3.580507755279541
Loading validation data
Creating data loaders


## Calibrate the model

In [7]:
def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistic"""

    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        model(image.cuda())
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
            
def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
#             print(F"{name:40}: {module}")
    model.cuda()

In [8]:
# It is a bit slow since we collect histograms on CPU
with torch.no_grad():
    collect_stats(model, data_loader, num_batches=2)
    compute_amax(model, method="percentile", percentile=99.99)

100%|██████████| 2/2 [04:50<00:00, 111.13s/it]

## Now evaluate the calibrated model

In [9]:
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
    evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)
    
# Save the model
torch.save(model.state_dict(), "/tmp/quant_resnet50-calibrated.pth")

Test:  [ 0/98]  eta: 0:05:53  loss: 0.5656 (0.5656)  acc1: 85.7422 (85.7422)  acc5: 96.0938 (96.0938)  time: 3.6079  data: 2.8152  max mem: 5880
Test:  [20/98]  eta: 0:01:07  loss: 0.6741 (0.6825)  acc1: 82.8125 (82.4219)  acc5: 95.8984 (95.7682)  time: 0.7343  data: 0.0002  max mem: 5882
Test:  [40/98]  eta: 0:00:46  loss: 0.6995 (0.7157)  acc1: 80.0781 (81.4024)  acc5: 96.0938 (95.7412)  time: 0.7226  data: 0.0002  max mem: 5882
Test:  [60/98]  eta: 0:00:29  loss: 1.1064 (0.8590)  acc1: 71.4844 (78.2627)  acc5: 91.0156 (94.1150)  time: 0.7259  data: 0.0002  max mem: 5882
Test:  [80/98]  eta: 0:00:13  loss: 1.1220 (0.9372)  acc1: 72.4609 (76.7072)  acc5: 89.6484 (93.1375)  time: 0.7220  data: 0.0002  max mem: 5882
Test: Total time: 0:01:13
 * Acc@1 76.138 Acc@5 92.916


## We can also try different calibrations and see which one works the best

In [10]:
with torch.no_grad():
    compute_amax(model, method="percentile", percentile=99.9)
    evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)

Test:  [ 0/98]  eta: 0:05:27  loss: 0.6037 (0.6037)  acc1: 84.9609 (84.9609)  acc5: 95.3125 (95.3125)  time: 3.3411  data: 2.6190  max mem: 5882
Test:  [20/98]  eta: 0:01:06  loss: 0.6760 (0.7041)  acc1: 81.2500 (81.7522)  acc5: 95.7031 (95.4892)  time: 0.7243  data: 0.0002  max mem: 5882
Test:  [40/98]  eta: 0:00:45  loss: 0.7241 (0.7351)  acc1: 79.1016 (80.7784)  acc5: 95.8984 (95.4459)  time: 0.7243  data: 0.0002  max mem: 5882
Test:  [60/98]  eta: 0:00:29  loss: 1.1162 (0.8793)  acc1: 71.4844 (77.6383)  acc5: 90.8203 (93.7948)  time: 0.7204  data: 0.0002  max mem: 5882
Test:  [80/98]  eta: 0:00:13  loss: 1.1498 (0.9603)  acc1: 71.4844 (76.0368)  acc5: 89.4531 (92.7156)  time: 0.7164  data: 0.0002  max mem: 5882
Test: Total time: 0:01:12
 * Acc@1 75.438 Acc@5 92.486


In [11]:
with torch.no_grad():
    for method in ["mse", "entropy"]:
        print(F"{method} calibration")
        compute_amax(model, method=method)
        evaluate(model, criterion, data_loader_test, device="cuda", print_freq=20)

mse calibration
Test:  [ 0/98]  eta: 0:06:34  loss: 0.5700 (0.5700)  acc1: 85.1562 (85.1562)  acc5: 96.2891 (96.2891)  time: 4.0243  data: 3.3231  max mem: 5882
Test:  [20/98]  eta: 0:01:08  loss: 0.6758 (0.6838)  acc1: 82.8125 (82.5707)  acc5: 96.0938 (95.7868)  time: 0.7204  data: 0.0002  max mem: 5882
Test:  [40/98]  eta: 0:00:46  loss: 0.7047 (0.7163)  acc1: 80.2734 (81.4834)  acc5: 96.2891 (95.7746)  time: 0.7178  data: 0.0002  max mem: 5882
Test:  [60/98]  eta: 0:00:29  loss: 1.1127 (0.8585)  acc1: 71.0938 (78.3395)  acc5: 90.8203 (94.1278)  time: 0.7192  data: 0.0002  max mem: 5882
Test:  [80/98]  eta: 0:00:13  loss: 1.1261 (0.9367)  acc1: 72.6562 (76.7530)  acc5: 89.8438 (93.1785)  time: 0.7176  data: 0.0002  max mem: 5882
Test: Total time: 0:01:13
 * Acc@1 76.186 Acc@5 92.926
entropy calibration
Test:  [ 0/98]  eta: 0:05:28  loss: 0.5648 (0.5648)  acc1: 85.3516 (85.3516)  acc5: 96.0938 (96.0938)  time: 3.3558  data: 2.6268  max mem: 5882
Test:  [20/98]  eta: 0:01:05  loss: 0.6