This IPython notebook lets you experiment with Quantization-Aware Training (QAT): you can introduce quantization into any network and then evaluate the quantized model within the PyTorch framework itself.

In [1]:
!pip install netron

import torch
import torch.nn as nn
import edgeai_torchmodelopt
import copy
import netron
import torchvision
from tqdm import tqdm





We define the model, the loss function, the optimizer, and an example input with the shape the network expects.

In [2]:
model = torchvision.models.resnet50(weights='DEFAULT')
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

example_input = torch.rand((1, 3, 224, 224))

y = model(example_input)
print("Output Shape is : {}".format(y.shape))

Output Shape is : torch.Size([1, 1000])


In [3]:
model_export_name = "./orig_simple_network_qat.onnx"
torch.onnx.export(model, example_input, model_export_name)
netron.start(model_export_name, 8082)

Serving './orig_simple_network_qat.onnx' at http://localhost:8082


('localhost', 8082)

Here we wrap the model in QATFxModule, which handles both the quantization-aware training of the model and its conversion to the final quantized network. It expects the number of epochs for which the model will be trained. It also supports bias calibration for layers that have a bias; setting a bias calibration factor (0.01 generally works well) enables it. Performing QAT for 25-50 epochs is suggested, as this helps the network stabilise. The wrapper's epoch counter is updated every time model.train() is called, so model.train() should be called once per epoch. The general guidelines are available [here](../edgeai_torchmodelopt/xmodelopt/quantization/v2/docs/guidelines.md).
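To make the epoch-counter behaviour concrete, here is a toy sketch in plain Python. It is not the library's implementation: the class EpochCountingWrapper, its attributes, and the two freeze thresholds are hypothetical names and values chosen for illustration. It only shows why a schedule driven by model.train() calls advances once per call, and why extra train() calls within an epoch would advance it too early.

```python
class EpochCountingWrapper:
    """Toy stand-in (NOT the library's code) for a wrapper that counts epochs via train()."""

    def __init__(self, total_epochs, freeze_bn_epoch, freeze_range_epoch):
        self.total_epochs = total_epochs
        self.freeze_bn_epoch = freeze_bn_epoch        # hypothetical threshold
        self.freeze_range_epoch = freeze_range_epoch  # hypothetical threshold
        self.num_epochs_tracked = 0
        self.bn_frozen = False
        self.ranges_frozen = False
        self.training = False

    def train(self, mode=True):
        self.training = mode
        if mode:
            # Compare the thresholds against the number of epochs started so far,
            # then advance the counter: one train() call == one epoch.
            if self.num_epochs_tracked >= self.freeze_bn_epoch:
                self.bn_frozen = True
            if self.num_epochs_tracked >= self.freeze_range_epoch:
                self.ranges_frozen = True
            self.num_epochs_tracked += 1
        return self

    def eval(self):
        # Switching to eval mode does not advance the counter.
        return self.train(False)


wrapper = EpochCountingWrapper(total_epochs=3, freeze_bn_epoch=1, freeze_range_epoch=2)
history = []
for epoch in range(wrapper.total_epochs):
    wrapper.train()  # exactly once per epoch
    history.append((wrapper.bn_frozen, wrapper.ranges_frozen))
    wrapper.eval()   # e.g. for validation between epochs

print(history)
```

With these assumed thresholds, nothing is frozen in the first epoch, batch-norm statistics are frozen from the second, and quantization ranges from the third. The real thresholds used by QATFxModule are not stated in this notebook.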

In [4]:
num_epochs = 3
model = edgeai_torchmodelopt.xmodelopt.quantization.v2.QATFxModule(model, backend='qnnpack', bias_calibration_factor=0.01, total_epochs=num_epochs)



The training step for the network follows. Random data is used here purely as an example. **Replace the data, loss, and optimizer with those for your own dataset.**

In [5]:
num_train_images = 10
for epoch in range(num_epochs):
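    model.train()  # called once per epoch; this also advances the QAT epoch counter
    for i in tqdm(range(num_train_images)):
        optimizer.zero_grad()
        output = model(torch.rand(1,3,224,224))  # random input, placeholder for real images
        label = torch.rand(1,1000)  # random soft targets, placeholder for real labels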
        loss = loss_fn(output, label) 
        loss.backward()
        optimizer.step()

Freezing BN for subsequent epochs


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:02<00:00,  3.63it/s]
100%|██████████| 10/10 [00:02<00:00,  3.63it/s]


Freezing ranges for subsequent epochs


100%|██████████| 10/10 [00:02<00:00,  3.95it/s]


We have the quantized and calibrated 8-bit network now.

In [6]:
model.eval()
print(model)

QATFxModule(
  (module): GraphModule(
    (activation_post_process_0): AdaptiveActivationFakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([0], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0039]), zero_point=tensor([0], dtype=torch.int32)
      (activation_post_process): CustomAdaptiveActivationObserverqscheme_torch_per_tensor_affine__range_shrink_percentile_0(min_val=2.0813342416658998e-05, max_val=0.9999921917915344)
    )
    (conv1): ConvBnReLU2d(
      3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (weight_fake_quant): AdaptiveWeightFakeQuantize(
        fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([0], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric
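
The printed activation quantizer reports quant_min=0, quant_max=255, scale=tensor([0.0039]) and zero_point=0, together with an observed range of roughly [2.08e-05, 1.0]. As a rough check, the sketch below applies the usual min/max formula for per-tensor affine quantization to that observed range. The helper affine_qparams is a hypothetical name, and the library's adaptive observer may derive its ranges differently (for example by shrinking outliers); this only shows that the printed values are consistent with the standard formula.

```python
def affine_qparams(min_val, max_val, quant_min=0, quant_max=255):
    """Standard min/max per-tensor affine parameters (illustrative helper, not library code)."""
    # Extend the range to include 0.0 so that zero is exactly representable.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (quant_max - quant_min)
    scale = max(scale, 1e-12)  # guard against a degenerate all-zero range
    zero_point = quant_min - round(min_val / scale)
    zero_point = max(quant_min, min(quant_max, zero_point))
    return scale, zero_point


# min_val and max_val copied from the observer line in the printed model above.
scale, zero_point = affine_qparams(2.0813342416658998e-05, 0.9999921917915344)
print(round(scale, 4), zero_point)  # prints: 0.0039 0
```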

In [7]:
model_export_name = "./converted_simple_network_qat.onnx"
model.export(example_input, model_export_name)
netron.start(model_export_name, 8082)


Stopping http://localhost:8082
Serving './converted_simple_network_qat.onnx' at http://localhost:8082


('localhost', 8082)

Netron may show the fused quantized operators as separate nodes because the exported model is fake-quantized (Q-DQ).
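For readers unfamiliar with the Q-DQ pattern, the following sketch shows what one quantize/dequantize pair computes on a single value. It is plain Python for illustration, not the exported ONNX operators or the library's code; the function names are hypothetical, and the scale of 1/255 with a zero point of 0 is an assumed example for an 8-bit unsigned activation in the range [0, 1].

```python
def quantize(x, scale, zero_point, quant_min=0, quant_max=255):
    # Map a real value to an integer level, then clamp to the 8-bit range.
    # Python's round() rounds ties to even.
    q = round(x / scale) + zero_point
    return max(quant_min, min(quant_max, q))


def dequantize(q, scale, zero_point):
    # Map the integer level back to a real value.
    return (q - zero_point) * scale


def fake_quantize(x, scale, zero_point, quant_min=0, quant_max=255):
    # Q followed by DQ: the result is a float restricted to representable levels.
    q = quantize(x, scale, zero_point, quant_min, quant_max)
    return dequantize(q, scale, zero_point)


scale, zero_point = 1.0 / 255, 0  # assumed example parameters

for x in (-0.1, 0.2, 0.25, 1.7):
    q = quantize(x, scale, zero_point)
    print(x, "->", q, "->", fake_quantize(x, scale, zero_point))
```

Values inside the representable range come back with an error of at most half a quantization step, while values outside it (here -0.1 and 1.7) are clamped to the nearest end of the range. Because the graph keeps these Q and DQ steps as explicit nodes around floating-point operators, a viewer displays them individually rather than as one fused integer operator.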