In [1]:
import torch
import torchsummary
from models import get_model

# Compute device for fp32 work: use the GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

#### [Post Training Dynamic Quantization](https://pytorch.org/docs/stable/quantization.html)

This is the simplest form of quantization to apply: the weights are quantized ahead of time, while the activations are quantized dynamically during inference. It is most useful when model execution time is dominated by loading weights from memory rather than by computing the matrix multiplications (e.g., with small batch sizes). PyTorch's quantized kernels run on CPU, so the quantized model and its inputs must stay on the CPU.

In [6]:
# Quantization applied on linear layers.
# Dynamically quantized kernels run on CPU only, so the model, the quantization
# step and the dummy inference all stay on CPU in this cell (moving the model to
# CUDA would make the int8 Linear layers receive CUDA tensors and fail).

modelname = "mobilenet"
modelpath = "./rsc/outputs/saved_model/mobilenet.pth"

model = get_model(modelname, 2, True)
# map_location="cpu" lets the checkpoint load even if it was saved from a GPU run.
model.load_state_dict(torch.load(modelpath, map_location="cpu"))
# eval() disables the classifier's Dropout so fp32 vs. int8 outputs are comparable.
model.to("cpu").eval()

model_int8 = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# torchsummary.summary defaults to device="cuda"; force CPU inputs for both models.
torchsummary.summary(model, (3,224,224), device="cpu")
torchsummary.summary(model_int8, (3,224,224), device="cpu")

# Stays on CPU to match the quantized model (Tensor.to returns a new tensor,
# so the original un-assigned `.to(DEVICE)` never moved it anyway).
dummy_input_fp32 = torch.rand((1,3,224,224))

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = model_int8(dummy_input_fp32)

Sequential(
  (0): Dropout(p=0.2, inplace=False)
  (1): Linear(in_features=1280, out_features=1000, bias=True)
)
Linear(in_features=1280, out_features=1000, bias=True)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
             ReLU6-3         [-1, 32, 112, 112]               0
            Conv2d-4         [-1, 32, 112, 112]             288
       BatchNorm2d-5         [-1, 32, 112, 112]              64
             ReLU6-6         [-1, 32, 112, 112]               0
            Conv2d-7         [-1, 16, 112, 112]             512
       BatchNorm2d-8         [-1, 16, 112, 112]              32
  InvertedResidual-9         [-1, 16, 112, 112]               0
           Conv2d-10         [-1, 96, 112, 112]           1,536
      BatchNorm2d-11         [-1, 96, 112, 112]             192

In [7]:
# quantization applied to linear layers (but increase linear layer size)
# layers (fp32) --> linear_layer (int8) --> activation (fp32) --> ...
#
# No checkpoint is loaded here: mobilenet.pth is loaded into a 2-class model in
# the previous cell, so its classifier weights presumably would not fit this
# 1000-class head.
# Dynamically quantized kernels run on CPU only, so everything stays on CPU.

modelname = "mobilenet"

model = get_model(modelname, 1000, True)
# eval() disables Dropout so the quantized model is exercised in inference mode.
model.to("cpu").eval()

model_int8 = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# torchsummary.summary defaults to device="cuda"; force CPU inputs for both models.
torchsummary.summary(model, (3,224,224), device="cpu")
torchsummary.summary(model_int8, (3,224,224), device="cpu")

# Kept on CPU to match the quantized model.
dummy_input_fp32 = torch.rand((1,3,224,224))

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = model_int8(dummy_input_fp32)

Sequential(
  (0): Dropout(p=0.2, inplace=False)
  (1): Linear(in_features=1280, out_features=1000, bias=True)
)
Linear(in_features=1280, out_features=1000, bias=True)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
             ReLU6-3         [-1, 32, 112, 112]               0
            Conv2d-4         [-1, 32, 112, 112]             288
       BatchNorm2d-5         [-1, 32, 112, 112]              64
             ReLU6-6         [-1, 32, 112, 112]               0
            Conv2d-7         [-1, 16, 112, 112]             512
       BatchNorm2d-8         [-1, 16, 112, 112]              32
  InvertedResidual-9         [-1, 16, 112, 112]               0
           Conv2d-10         [-1, 96, 112, 112]           1,536
      BatchNorm2d-11         [-1, 96, 112, 112]             192

#### [Quantization Aware Training](https://pytorch.org/docs/stable/quantization.html)

Quantization Aware Training (QAT) models the effects of quantization during training, which typically yields higher accuracy than the other quantization methods. QAT can be used for static, dynamic, or weight-only quantization. During training, all calculations are done in floating point, with fake_quant modules modeling the effects of quantization by clamping and rounding values to simulate INT8.

In [14]:
import copy

import torchvision
from torch.ao.quantization import get_default_qat_qconfig, quantize_fx
from train import train_step, validation_step

SEED = 0
NUM_EPOCHS = 1
QUANTIZED_PATH = "./rsc/outputs/saved_model/quantized.pth"

# Deterministic cuDNN only yields reproducible runs together with a fixed seed
# (the DataLoader below shuffles).
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(255),
        # 224 matches example_inputs below and the (3,224,224) inputs used
        # elsewhere in this notebook (previously 225, an apparent typo).
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[.485,.456,.406], std=[.229,.224,.225])
    ])

train_dataset = torchvision.datasets.ImageFolder(root="./../data", transform=transforms)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True)

model = get_model("mobilenet", 2, True)
# NOTE(review): the qconfig targets the qnnpack backend; when running the
# converted model, torch.backends.quantized.engine likely needs to be "qnnpack".
qconfig = {"": get_default_qat_qconfig("qnnpack")}
example_inputs = (torch.randn(1, 3, 224, 224),)
# prepare_qat_fx requires the model in training mode; it inserts fake-quant modules.
model = quantize_fx.prepare_qat_fx(model.train(), qconfig, example_inputs)
model.to(DEVICE)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    _, _ = train_step(model, train_loader, optimizer, criterion, device=DEVICE)

    # Convert a CPU deep copy so the QAT model itself keeps training on DEVICE.
    model_quantized = copy.deepcopy(model)
    model_quantized.to(torch.device("cpu"))
    model_quantized = quantize_fx.convert_fx(model_quantized.eval())
    # state_dict() must be called: saving the bound method stores no weights.
    torch.save(model_quantized.state_dict(), QUANTIZED_PATH)



100%|██████████| 100/100 [00:28<00:00,  3.54it/s]
