Following the calibrate_quant_resnet50 example, now we fine tune the model

In [1]:
import datetime
import os
import sys
import time

import torch
import torch.utils.data
from torch import nn

from tqdm import tqdm

import torchvision
from torchvision import transforms

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

from absl import logging
logging.set_verbosity(logging.FATAL)  # Disable logging as they are too noisy in notebook

W0608 21:25:39.018203 140228493526848 tensor_quant.py:96] Meaning of axis has changed since v2.0. Make sure to update.
W0608 21:25:39.019082 140228493526848 tensor_quant.py:96] Meaning of axis has changed since v2.0. Make sure to update.
W0608 21:25:39.019555 140228493526848 tensor_quant.py:96] Meaning of axis has changed since v2.0. Make sure to update.
W0608 21:25:39.020030 140228493526848 tensor_quant.py:96] Meaning of axis has changed since v2.0. Make sure to update.
W0608 21:25:39.020492 140228493526848 tensor_quant.py:96] Meaning of axis has changed since v2.0. Make sure to update.
W0608 21:25:39.020947 140228493526848 tensor_quant.py:96] Meaning of axis has changed since v2.0. Make sure to update.
W0608 21:25:39.021392 140228493526848 tensor_quant.py:96] Meaning of axis has changed since v2.0. Make sure to update.


In [2]:
# For simplicity, import train and eval functions from the train script from torchvision instead of copything them here
sys.path.append("/raid/skyw/models/torchvision/references/classification/")
from train import evaluate, train_one_epoch, load_data

In [3]:
from pytorch_quantization import quant_modules
quant_modules.initialize()

# Create and load the calibrated model
model = torchvision.models.resnet50()
model.load_state_dict(torch.load("/tmp/quant_resnet50-calibrated.pth"))
model.cuda()

QuantResNet(
  (conv1): QuantConv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=2.6387 calibrator=MaxCalibrator(track_amax=False) quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=[0.0000, 0.7817](64) calibrator=MaxCalibrator(track_amax=False) quant)
  )
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): QuantConv2d(
        64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
        (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=2.9730 calibrator=MaxCalibrator(track_amax=False) quant)
        (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=[0.0000, 0.7266](64) calibrator=MaxCalibrator(track_amax=False) quant)
      )
      (bn1): Ba

In [4]:
data_path = "/raid/data/imagenet/imagenet_pytorch"

traindir = os.path.join(data_path, 'train')
valdir = os.path.join(data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(traindir, valdir, False, False)

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=256,
    sampler=train_sampler, num_workers=4, pin_memory=True)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=256,
    sampler=test_sampler, num_workers=4, pin_memory=True)

## Quantized fine tuning
Let's fine tune the model with fake quantization. We only fine tune for 1 epoch as an example.

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=128,
    sampler=train_sampler, num_workers=16, pin_memory=True)

# Training takes about one and half hour per epoch on single V100
train_one_epoch(model, criterion, optimizer, data_loader, "cuda", 0, 100)

## Evaluate the fine tuned model

In [6]:
with torch.no_grad():
    evaluate(model, criterion, data_loader_test, device="cuda")

Test:  [  0/196]  eta: 0:09:36  loss: 0.4680 (0.4680)  acc1: 85.9375 (85.9375)  acc5: 98.0469 (98.0469)  time: 2.9406  data: 2.1852  max mem: 16096
Test:  [ 10/196]  eta: 0:01:58  loss: 0.6694 (0.6522)  acc1: 83.2031 (82.9545)  acc5: 96.0938 (96.1293)  time: 0.6346  data: 0.1988  max mem: 16096
Test:  [ 20/196]  eta: 0:01:30  loss: 0.6738 (0.6733)  acc1: 82.0312 (82.4777)  acc5: 95.7031 (95.7961)  time: 0.3928  data: 0.0001  max mem: 16096
Test:  [ 30/196]  eta: 0:01:18  loss: 0.6219 (0.6322)  acc1: 84.3750 (83.9718)  acc5: 95.7031 (96.0181)  time: 0.3859  data: 0.0001  max mem: 16096
Test:  [ 40/196]  eta: 0:01:10  loss: 0.6801 (0.6750)  acc1: 81.6406 (82.5934)  acc5: 95.7031 (95.9604)  time: 0.3861  data: 0.0001  max mem: 16096
Test:  [ 50/196]  eta: 0:01:04  loss: 0.6937 (0.6724)  acc1: 80.0781 (82.3529)  acc5: 96.8750 (96.0938)  time: 0.3834  data: 0.0001  max mem: 16096
Test:  [ 60/196]  eta: 0:00:58  loss: 0.7149 (0.6849)  acc1: 80.0781 (81.9864)  acc5: 96.4844 (96.1066)  time: 0

After only 1 epoch of quantized fine tuning, top-1 improved from ~76.1 to 76.426. Train longer with lr anealing can improve accuracy futher