In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader
import torch.quantization as quantization

from utils.dataset import MNIST
from utils.trainer import Trainer
from utils.model import MobileNetv2, ConvBnRelu, ConvBn

In [2]:
DATA_PATH = "data/digit-recognizer/"
MODEL_FILE = "data/model.pth"
device = "cuda"

seed = 42
batch_size = 256

In [3]:
df = pd.read_csv(DATA_PATH + "train.csv")
y = df["label"].values
X = df.drop("label", axis=1).values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=seed)

In [4]:
train_dataset = MNIST(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

test_dataset = MNIST(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [5]:
n_epochs = 6
lr = 0.1

torch.manual_seed(seed)
model = MobileNetv2().to(device)

In [6]:
trainer = Trainer(model, train_loader, test_loader, seed, lr=lr, momentum=0.9, weight_decay=4e-5)

for epoch in range(0, n_epochs):
    trainer.run_one_epoch(epoch)

Epoch 0: train loss 0.000991, test loss 0.000259, test accuracy 0.9804
Epoch 5: train loss 3e-06, test loss 0.000109, test accuracy 0.992


In [7]:
trainer.model = trainer.model.to("cpu")
trainer.device = "cpu"

In [8]:
%%time
_, acc = trainer.validate()
print(f"Float accruacy: {acc}")

Float accruacy: 0.992
CPU times: user 1min 27s, sys: 41.9 s, total: 2min 9s
Wall time: 33 s


In [9]:
torch.save(model.state_dict(), MODEL_FILE)

In [10]:
class QuantMobileNet(MobileNetv2):
    def __init__(self):
        super().__init__()
        self.quant = quantization.QuantStub()
        self.dequant = quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = super().forward(x)        
        x = self.dequant(x)
        return x

    # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
    def fuse_model(self):
        for module in self.modules():
            if isinstance(module, ConvBnRelu):
                torch.quantization.fuse_modules(module, ['conv', 'bn', 'act'], inplace=True)
            elif isinstance(module, ConvBn):
                torch.quantization.fuse_modules(module, ['conv', 'bn'], inplace=True)

# Per-tensor quantization

In [11]:
# min/max range estimation and per-tensor quantization of weights
per_tensor_quant_model = QuantMobileNet().to('cpu')
_ = per_tensor_quant_model.load_state_dict(torch.load(MODEL_FILE))
per_tensor_quant_model.eval()
per_tensor_quant_model.fuse_model()
per_tensor_quant_model

QuantMobileNet(
  (features): Sequential(
    (0): ConvBnRelu(
      (conv): ConvReLU2d(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (bn): Identity()
      (act): Identity()
    )
    (1): Block(
      (bn_layer_1x1_before): ConvBnRelu(
        (conv): ConvReLU2d(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU()
        )
        (bn): Identity()
        (act): Identity()
      )
      (bn_layer_3x3): ConvBnRelu(
        (conv): ConvReLU2d(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=192)
          (1): ReLU()
        )
        (bn): Identity()
        (act): Identity()
      )
      (bn_layer_1x1_after): ConvBn(
        (conv): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
        (bn): Identity()
      )
      (skip_add): FloatFunctional(
        (observer): Identity()
      )
    )
    (2): Block(
      (bn_layer_1x1_before): C

In [12]:
per_tensor_quant_model.qconfig = quantization.default_qconfig
_ = torch.quantization.prepare(per_tensor_quant_model, inplace=True)

In [13]:
per_tensor_quant_trainer = Trainer(per_tensor_quant_model, train_loader, test_loader, seed, device="cpu",
                        lr=lr, momentum=0.9, weight_decay=4e-5)

In [14]:
# Calibrating the model
_ = per_tensor_quant_trainer.validate()

In [15]:
torch.quantization.convert(per_tensor_quant_trainer.model, inplace=True)

  Returning default scale and zero point "


QuantMobileNet(
  (features): Sequential(
    (0): ConvBnRelu(
      (conv): QuantizedConvReLU2d(1, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.055154748260974884, zero_point=0, padding=(1, 1))
      (bn): Identity()
      (act): Identity()
    )
    (1): Block(
      (bn_layer_1x1_before): ConvBnRelu(
        (conv): QuantizedConvReLU2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=0.08988691121339798, zero_point=0)
        (bn): Identity()
        (act): Identity()
      )
      (bn_layer_3x3): ConvBnRelu(
        (conv): QuantizedConvReLU2d(192, 192, kernel_size=(3, 3), stride=(2, 2), scale=0.11506599932909012, zero_point=0, padding=(1, 1), groups=192)
        (bn): Identity()
        (act): Identity()
      )
      (bn_layer_1x1_after): ConvBn(
        (conv): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=0.12912055850028992, zero_point=70)
        (bn): Identity()
      )
      (skip_add): QFunctional()
    )
    (2): Block(
      (bn_layer_1x1_before): 

In [16]:
%%time
_, acc = per_tensor_quant_trainer.validate()
print(f"Per-tensor quant accruacy: {acc}")

Per-tensor quant accruacy: 0.9917
CPU times: user 10.1 s, sys: 1.51 s, total: 11.6 s
Wall time: 3.05 s


# Channel-wise quantization

In [17]:
per_channel_quant_model = QuantMobileNet().to('cpu')
_ = per_channel_quant_model.load_state_dict(torch.load(MODEL_FILE))
per_channel_quant_model.eval()
per_channel_quant_model.fuse_model()

In [18]:
# Channel-wise quant
per_channel_quant_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
_ = torch.quantization.prepare(per_channel_quant_model, inplace=True)

per_channel_quant_trainer = Trainer(per_channel_quant_model, train_loader, test_loader, seed, device="cpu",
                                    lr=lr, momentum=0.9, weight_decay=4e-5)

In [19]:
_ = per_channel_quant_trainer.validate()

In [20]:
torch.quantization.convert(per_channel_quant_trainer.model, inplace=True)

  Returning default scale and zero point "


QuantMobileNet(
  (features): Sequential(
    (0): ConvBnRelu(
      (conv): QuantizedConvReLU2d(1, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.04551343992352486, zero_point=0, padding=(1, 1))
      (bn): Identity()
      (act): Identity()
    )
    (1): Block(
      (bn_layer_1x1_before): ConvBnRelu(
        (conv): QuantizedConvReLU2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=0.05139530077576637, zero_point=0)
        (bn): Identity()
        (act): Identity()
      )
      (bn_layer_3x3): ConvBnRelu(
        (conv): QuantizedConvReLU2d(192, 192, kernel_size=(3, 3), stride=(2, 2), scale=0.06551120430231094, zero_point=0, padding=(1, 1), groups=192)
        (bn): Identity()
        (act): Identity()
      )
      (bn_layer_1x1_after): ConvBn(
        (conv): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=0.10478436201810837, zero_point=68)
        (bn): Identity()
      )
      (skip_add): QFunctional()
    )
    (2): Block(
      (bn_layer_1x1_before): C

In [21]:
%%time
_, acc = per_channel_quant_trainer.validate()
print(f"Per-tensor quant accruacy: {acc}")

Per-tensor quant accruacy: 0.9919
CPU times: user 10.7 s, sys: 1.09 s, total: 11.8 s
Wall time: 3.17 s


# Quantization-aware training

In [22]:
aware_quant_model = QuantMobileNet().to('cpu')
_ = aware_quant_model.load_state_dict(torch.load(MODEL_FILE))
aware_quant_model.train()
aware_quant_model.fuse_model()

In [23]:
aware_quant_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
_ = torch.quantization.prepare_qat(aware_quant_model, inplace=True)

aware_quant_trainer = Trainer(aware_quant_model, train_loader, test_loader, seed, device="cpu",
                              lr=lr / 100, momentum=0.9)

In [24]:
aware_quant_trainer.run_one_epoch(0)

Epoch 0: train loss 1.5e-05, test loss 0.000108, test accuracy 0.992


In [25]:
aware_quant_trainer.model.eval()
quantization.convert(aware_quant_trainer.model, inplace=True)

QuantMobileNet(
  (features): Sequential(
    (0): ConvBnRelu(
      (conv): QuantizedConvReLU2d(1, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.04378524422645569, zero_point=0, padding=(1, 1))
      (bn): Identity()
      (act): Identity()
    )
    (1): Block(
      (bn_layer_1x1_before): ConvBnRelu(
        (conv): QuantizedConvReLU2d(32, 192, kernel_size=(1, 1), stride=(1, 1), scale=0.04960907623171806, zero_point=0)
        (bn): Identity()
        (act): Identity()
      )
      (bn_layer_3x3): ConvBnRelu(
        (conv): QuantizedConvReLU2d(192, 192, kernel_size=(3, 3), stride=(2, 2), scale=0.06527788192033768, zero_point=0, padding=(1, 1), groups=192)
        (bn): Identity()
        (act): Identity()
      )
      (bn_layer_1x1_after): ConvBn(
        (conv): QuantizedConv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), scale=0.0976630300283432, zero_point=66)
        (bn): Identity()
      )
      (skip_add): QFunctional()
    )
    (2): Block(
      (bn_layer_1x1_before): Co

In [26]:
%%time
_, acc = aware_quant_trainer.validate()
print(f"Aware quant accruacy: {acc}")

Aware quant accruacy: 0.9922
CPU times: user 11 s, sys: 999 ms, total: 12 s
Wall time: 3.26 s
