In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader
import torch.quantization as quantization

from utils.dataset import MNIST
from utils.trainer import Trainer
from utils.model import MobileNetv2, ConvBnRelu, ConvBn

In [2]:
DATA_PATH = "data/digit-recognizer/"
MODEL_FILE = "data/model.pth"
device = "cuda"

seed = 42
batch_size = 256

In [3]:
df = pd.read_csv(DATA_PATH + "train.csv")
y = df["label"].values
X = df.drop("label", axis=1).values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=seed)

In [4]:
train_dataset = MNIST(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

test_dataset = MNIST(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [5]:
n_epochs = 6
lr = 0.1

torch.manual_seed(seed)
model = MobileNetv2().to(device)

In [6]:
trainer = Trainer(model, train_loader, test_loader, seed, lr=lr, momentum=0.9, weight_decay=4e-5)

for epoch in range(0, n_epochs):
    trainer.run_one_epoch(epoch)

Epoch 0: train loss 0.000985, test loss 0.0003, test accuracy 0.9769
Epoch 5: train loss 3e-06, test loss 0.000111, test accuracy 0.9915


In [7]:
trainer.model = trainer.model.to("cpu")
trainer.device = "cpu"

In [8]:
%%time
_, acc = trainer.validate()
print(f"Float accruacy: {acc}")

Float accruacy: 0.9915
CPU times: user 1min 25s, sys: 39.7 s, total: 2min 5s
Wall time: 31.6 s


In [9]:
torch.save(model.state_dict(), MODEL_FILE)

In [10]:
class QuantMobileNet(MobileNetv2):
    def __init__(self):
        super().__init__()
        self.quant = quantization.QuantStub()
        self.dequant = quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = super().forward(x)        
        x = self.dequant(x)
        return x

    # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
    def fuse_model(self):
        for module in self.modules():
            if isinstance(module, ConvBnRelu):
                torch.quantization.fuse_modules(module, ['conv', 'bn', 'act'], inplace=True)
            elif isinstance(module, ConvBn):
                torch.quantization.fuse_modules(module, ['conv', 'bn'], inplace=True)

# Per-tensor quantization

In [11]:
# min/max range estimation and per-tensor quantization of weights
per_tensor_quant_model = QuantMobileNet().to('cpu')
_ = per_tensor_quant_model.load_state_dict(torch.load(MODEL_FILE))
per_tensor_quant_model.eval()
per_tensor_quant_model.fuse_model()

In [12]:
per_tensor_quant_model.qconfig = quantization.default_qconfig
_ = torch.quantization.prepare(per_tensor_quant_model, inplace=True)

In [13]:
per_tensor_quant_trainer = Trainer(per_tensor_quant_model, train_loader, test_loader, seed, device="cpu",
                        lr=lr, momentum=0.9, weight_decay=4e-5)

In [14]:
# Calibrating the model
_ = per_tensor_quant_trainer.validate()

In [15]:
_ = torch.quantization.convert(per_tensor_quant_trainer.model, inplace=True)

  Returning default scale and zero point "


In [16]:
%%time
_, acc = per_tensor_quant_trainer.validate()
print(f"Per-tensor quant accruacy: {acc}")

Per-tensor quant accruacy: 0.9916
CPU times: user 10 s, sys: 1.24 s, total: 11.3 s
Wall time: 2.92 s


# Channel-wise quantization

In [17]:
per_channel_quant_model = QuantMobileNet().to('cpu')
_ = per_channel_quant_model.load_state_dict(torch.load(MODEL_FILE))
per_channel_quant_model.eval()
per_channel_quant_model.fuse_model()

In [18]:
# Channel-wise quant
per_channel_quant_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
_ = torch.quantization.prepare(per_channel_quant_model, inplace=True)

per_channel_quant_trainer = Trainer(per_channel_quant_model, train_loader, test_loader, seed, device="cpu",
                                    lr=lr, momentum=0.9, weight_decay=4e-5)

In [19]:
_ = per_channel_quant_trainer.validate()

In [20]:
_ = torch.quantization.convert(per_channel_quant_trainer.model, inplace=True)

  Returning default scale and zero point "


In [21]:
%%time
_, acc = per_channel_quant_trainer.validate()
print(f"Per-tensor quant accruacy: {acc}")

Per-tensor quant accruacy: 0.9919
CPU times: user 10.9 s, sys: 1.25 s, total: 12.2 s
Wall time: 3.2 s


# Quantization-aware training

In [22]:
aware_quant_model = QuantMobileNet().to('cpu')
_ = aware_quant_model.load_state_dict(torch.load(MODEL_FILE))
aware_quant_model.train()
aware_quant_model.fuse_model()

In [23]:
aware_quant_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
_ = torch.quantization.prepare_qat(aware_quant_model, inplace=True)

aware_quant_trainer = Trainer(aware_quant_model, train_loader, test_loader, seed, device="cpu",
                              lr=lr / 100, momentum=0.9)

In [24]:
aware_quant_trainer.run_one_epoch(0)

Epoch 0: train loss 1.7e-05, test loss 0.00011, test accuracy 0.992


In [25]:
aware_quant_trainer.model.eval()
_ = quantization.convert(aware_quant_trainer.model, inplace=True)

In [26]:
%%time
_, acc = aware_quant_trainer.validate()
print(f"Aware quant accruacy: {acc}")

Aware quant accruacy: 0.9915
CPU times: user 10.9 s, sys: 880 ms, total: 11.8 s
Wall time: 3.13 s
