# Import the necessary libraries

In [1]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os
import torch.ao.quantization as quant

# Load MNIST dataset

In [2]:
torch.manual_seed(101)

<torch._C.Generator at 0x2469e1a8e30>

In [8]:
# For MNIST...
# mean value = 0.0.1307
# var = 0.3081
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='data/', train = True, transform = transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size = 64, shuffle = True)

mnist_testset = datasets.MNIST(root='data/', train = False, transform = transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size = 64, shuffle = True)

device = "cpu"

# Define the model
The architecture of model and the training and testing processes are same, only few modification to show the training result.

* `fuse_model`: Fuse two layers together, since qint8 cannot be dealt with normal arithmetic.

In [9]:
class SimpleNet(nn.Module):
    def __init__(self, neuron_1 = 64, neuron_2 = 64):
        super().__init__()
        self.quant = quant.QuantStub()
        self.linear1 = nn.Linear(28 * 28, neuron_1)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(neuron_1, neuron_2)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(neuron_2, 10)
        self.dequant = quant.DeQuantStub()
        

    def forward(self, img):
        x = img.reshape(-1, 28 * 28)
        x = self.quant(x)
        x = self.relu1(self.linear1(x))
        x = self.relu2(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x
    
    def fuse_model(self):
        torch.ao.quantization.fuse_modules(self, [['linear1', 'relu1'], ['linear2', 'relu2']], inplace=True)

In [10]:
net = SimpleNet().to(device)

# Train the model

* `print_size_of_model`: Calculate the size of model
    1. Save the model
    2. `os.path.getsize`
    3. Remove the saving model

In [11]:
def train(train_loader, net, epochs = 5, total_iterations_limit = None):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc = f'Epoch {epoch + 1}') # shows the training process
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            
            output = net(x)
            loss = loss_fn(output, y)
            loss_sum += loss
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss) # post_fix for tqdm

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
    
def print_size_of_model(model):
    torch.save(model.state_dict(), 'temp_delme.p')
    print('Size (KB):', os.path.getsize('temp_delme.p') / 1e3)
    os.remove('temp_delme.p')

MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=5)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)



Epoch 1: 100%|██████████| 938/938 [00:20<00:00, 46.54it/s, loss=tensor(0.3070, grad_fn=<DivBackward0>)]
Epoch 2: 100%|██████████| 938/938 [00:20<00:00, 45.28it/s, loss=tensor(0.1362, grad_fn=<DivBackward0>)]
Epoch 3: 100%|██████████| 938/938 [00:20<00:00, 46.24it/s, loss=tensor(0.0981, grad_fn=<DivBackward0>)]
Epoch 4: 100%|██████████| 938/938 [00:20<00:00, 45.74it/s, loss=tensor(0.0758, grad_fn=<DivBackward0>)]
Epoch 5: 100%|██████████| 938/938 [00:20<00:00, 45.65it/s, loss=tensor(0.0629, grad_fn=<DivBackward0>)]


# Define the testing loop

In [12]:
def test(model, total_iterations=None):
    model.eval()
    model.to("cpu")

    correct = 0
    total = 0
    iterations = 0

    with torch.inference_mode():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to("cpu")
            y = y.to("cpu")

            output = model(x)

            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                total += 1
            iterations += 1

            if total_iterations is not None and iterations >= total_iterations:
                break

    print(f'Accuracy: {round(correct / total, 3)}')


# Print weights and size of the model before quantization

In [13]:
# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[ 0.0128, -0.0026, -0.0295,  ..., -0.0194,  0.0109, -0.0084],
        [ 0.0464,  0.0048,  0.0133,  ...,  0.0520,  0.0025, -0.0120],
        [ 0.0395,  0.0317, -0.0237,  ..., -0.0278, -0.0207,  0.0277],
        ...,
        [ 0.0363,  0.0409,  0.0370,  ...,  0.0066,  0.0480,  0.0488],
        [-0.0257, -0.0119,  0.0268,  ...,  0.0260, -0.0261, -0.0349],
        [ 0.0340,  0.0216,  0.0462,  ...,  0.0387,  0.0491,  0.0531]],
       requires_grad=True)
torch.float32


In [14]:
print('Size of the model before quantization')
print_size_of_model(net)

Size of the model before quantization
Size (KB): 222.886


In [15]:
print(f'Accuracy of the model before quantization: ')
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 157/157 [00:02<00:00, 71.69it/s]

Accuracy: 0.971





# Start to Quantization

In [16]:
print(type(net.linear1))

<class 'torch.nn.modules.linear.Linear'>


1. `get_default_qconfig`: Call API to start quantization. Define the configuration as "fbgemm".

In [17]:
import torch.ao.quantization

net.fuse_model()
net.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
print(net.qconfig)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})


2. `prepare`: Insert the observers for quantization.

In [18]:

torch.ao.quantization.prepare(net, inplace = True)
# Calibrate first
print('Post Training Quantization Prepare: Inserting Observers')
print('\n Inverted Residual Block:After observer insertion \n\n', net.linear1)

Post Training Quantization Prepare: Inserting Observers

 Inverted Residual Block:After observer insertion 

 LinearReLU(
  (0): Linear(in_features=784, out_features=64, bias=True)
  (1): ReLU()
  (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
)




3. Calibration: Training (testing) the model for one epoch to get the parameters in obervers.

In [19]:
test(net)
print('Post Training Quantization: Calibration done')

Testing: 100%|██████████| 157/157 [00:02<00:00, 62.71it/s]

Accuracy: 0.971
Post Training Quantization: Calibration done





4. `convert`: Convert the model into quantized model.

In [20]:
torch.ao.quantization.convert(net, inplace = True)
print('Post Training Quantization: Convert done')
print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',net.linear1)

Post Training Quantization: Convert done

 Inverted Residual Block: After fusion and quantization, note fused modules: 

 QuantizedLinearReLU(in_features=784, out_features=64, scale=0.22460556030273438, zero_point=0, qscheme=torch.per_channel_affine)


In [21]:
print("Size of model after quantization")
print_size_of_model(net)

Size of model after quantization
Size (KB): 64.354


In [22]:
test(net)

Testing: 100%|██████████| 157/157 [00:02<00:00, 73.35it/s]

Accuracy: 0.972





If the model is successfully quantized, the name of the layer will start with `quantized...`.

In [23]:
print(net)

SimpleNet(
  (quant): Quantize(scale=tensor([0.0255]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinearReLU(in_features=784, out_features=64, scale=0.22460556030273438, zero_point=0, qscheme=torch.per_channel_affine)
  (relu1): Identity()
  (linear2): QuantizedLinearReLU(in_features=64, out_features=64, scale=0.1565181165933609, zero_point=0, qscheme=torch.per_channel_affine)
  (relu2): Identity()
  (linear3): QuantizedLinear(in_features=64, out_features=10, scale=0.44150465726852417, zero_point=76, qscheme=torch.per_channel_affine)
  (dequant): DeQuantize()
)


* **scale**: The step size between quantized levels.
* **zero_point**: The integer value in the quantized range that maps to real zero.

In [25]:
print(torch.dequantize(net.linear1.weight()))
print("----------------------------------------------------------------------------")
print(net.linear1.scale)
print(net.linear1.zero_point)


tensor([[ 0.0131, -0.0022, -0.0305,  ..., -0.0196,  0.0109, -0.0087],
        [ 0.0472,  0.0054,  0.0127,  ...,  0.0526,  0.0018, -0.0127],
        [ 0.0386,  0.0331, -0.0248,  ..., -0.0276, -0.0220,  0.0276],
        ...,
        [ 0.0358,  0.0409,  0.0358,  ...,  0.0077,  0.0486,  0.0486],
        [-0.0252, -0.0126,  0.0270,  ...,  0.0252, -0.0252, -0.0342],
        [ 0.0350,  0.0225,  0.0450,  ...,  0.0375,  0.0501,  0.0526]])
----------------------------------------------------------------------------
0.22460556030273438
0


* `model.linear1.weight().int_repr()`: Accessing the quantized tensor; use .int_repr() to retrieve its underlying int8 values.

In [26]:
weight_q = net.linear1.weight()
weight_int8 = weight_q.int_repr()
print(weight_int8)


tensor([[  6,  -1, -14,  ...,  -9,   5,  -4],
        [ 26,   3,   7,  ...,  29,   1,  -7],
        [ 14,  12,  -9,  ..., -10,  -8,  10],
        ...,
        [ 14,  16,  14,  ...,   3,  19,  19],
        [-14,  -7,  15,  ...,  14, -14, -19],
        [ 14,   9,  18,  ...,  15,  20,  21]], dtype=torch.int8)
