# Import the necessary libraries

In [42]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

# Load the MNIST dataset

In [43]:
# Make torch deterministic
_ = torch.manual_seed(0)

In [44]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = "cpu"

# Define the model

In [45]:
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [46]:
net = VerySimpleNet().to(device)

# Train the model

In [47]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
            
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=1)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)

Loaded model from disk


# Define the testing loop

In [48]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

# Print weights and size of the model before quantization

In [49]:
# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[-0.0068,  0.0126, -0.0359,  ...,  0.0154, -0.0028, -0.0045],
        [-0.0141, -0.0093, -0.0048,  ..., -0.0146, -0.0003, -0.0243],
        [ 0.0251,  0.0601,  0.0120,  ...,  0.0249,  0.0464,  0.0533],
        ...,
        [ 0.0564,  0.0601,  0.0255,  ...,  0.0201,  0.0394,  0.0024],
        [-0.0070,  0.0011,  0.0332,  ...,  0.0135,  0.0135,  0.0130],
        [ 0.0103,  0.0049, -0.0092,  ...,  0.0272, -0.0221, -0.0020]],
       requires_grad=True)
torch.float32


In [50]:
print('Size of the model before quantization')
print_size_of_model(net)

Size of the model before quantization
Size (KB): 360.559


In [51]:
print(f'Accuracy of the model before quantization: ')
test(net)

Accuracy of the model before quantization: 


Testing:  15%|█▌        | 150/1000 [00:00<00:01, 709.86it/s]

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 799.89it/s]

Accuracy: 0.963





# Insert min-max observers in the model

In [52]:
class QuantizedVerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(QuantizedVerySimpleNet,self).__init__() # Just python 2 syntax. super().__init__ works as well!
        self.quant = torch.quantization.QuantStub() # QuantStub added
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub() # QuantStub added

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

In [53]:
net_quantized = QuantizedVerySimpleNet().to(device)
# Copy weights from unquantized model, new weights not yet quantized
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

# Note how the min max values are initialised with -inf
# ao.quantization.PREPARE
net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers
net_quantized

QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Calibrate the model using the test set

In [54]:
# need to perform a forward pass for quant stubs to do their work
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 750.89it/s]

Accuracy: 0.963





In [55]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-53.58397674560547, max_val=34.898128509521484)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-24.331275939941406, max_val=26.62542152404785)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-28.273700714111328, max_val=20.937761306762695)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Quantize the model using the statistics collected

In [56]:
# # ao.quantization.CONVERT
net_quantized = torch.ao.quantization.convert(net_quantized)

In [57]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6967094540596008, zero_point=77, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.40123382210731506, zero_point=61, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.3874918520450592, zero_point=73, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

# Print weights of the model after quantization

In [58]:
# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights after quantization
tensor([[-2,  3, -8,  ...,  4, -1, -1],
        [-3, -2, -1,  ..., -3,  0, -6],
        [ 6, 14,  3,  ...,  6, 11, 12],
        ...,
        [13, 14,  6,  ...,  5,  9,  1],
        [-2,  0,  8,  ...,  3,  3,  3],
        [ 2,  1, -2,  ...,  6, -5,  0]], dtype=torch.int8)


# Compare the dequantized weights and the original weights

In [59]:
print('Original weights: ')
print(net.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(net_quantized.linear1.weight()))
print('')

Original weights: 
Parameter containing:
tensor([[-0.0068,  0.0126, -0.0359,  ...,  0.0154, -0.0028, -0.0045],
        [-0.0141, -0.0093, -0.0048,  ..., -0.0146, -0.0003, -0.0243],
        [ 0.0251,  0.0601,  0.0120,  ...,  0.0249,  0.0464,  0.0533],
        ...,
        [ 0.0564,  0.0601,  0.0255,  ...,  0.0201,  0.0394,  0.0024],
        [-0.0070,  0.0011,  0.0332,  ...,  0.0135,  0.0135,  0.0130],
        [ 0.0103,  0.0049, -0.0092,  ...,  0.0272, -0.0221, -0.0020]],
       requires_grad=True)

Dequantized weights: 
tensor([[-0.0087,  0.0131, -0.0348,  ...,  0.0174, -0.0044, -0.0044],
        [-0.0131, -0.0087, -0.0044,  ..., -0.0131,  0.0000, -0.0261],
        [ 0.0261,  0.0609,  0.0131,  ...,  0.0261,  0.0479,  0.0522],
        ...,
        [ 0.0566,  0.0609,  0.0261,  ...,  0.0218,  0.0392,  0.0044],
        [-0.0087,  0.0000,  0.0348,  ...,  0.0131,  0.0131,  0.0131],
        [ 0.0087,  0.0044, -0.0087,  ...,  0.0261, -0.0218,  0.0000]])



# Print size and accuracy of the quantized model

In [60]:
print('Size of the model after quantization')
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 94.955


In [61]:
print('Testing the model after quantization')
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 774.94it/s]

Accuracy: 0.963



