# Import the necessary libraries

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import os
from tqdm import tqdm
from pathlib import Path

# Load the MNIST dataset

In [2]:
# Make torch deterministic: seed PyTorch's default random number generator.
# Binding the returned Generator to `_` keeps it from being echoed as cell output.
_ = torch.manual_seed(0)

In [4]:
# Number of images per mini-batch, shared by the train and test loaders
BATCH_SIZE = 10

# Convert images to tensors, then normalise them
# (0.1307 / 0.3081 are presumably the MNIST training-set mean and std — verify)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Load the MNIST training set (downloaded into ./data on first use)
mnist_trainset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(
    mnist_trainset, batch_size=BATCH_SIZE, shuffle=True
)

# Load the MNIST test set
mnist_testset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
test_loader = torch.utils.data.DataLoader(
    mnist_testset, batch_size=BATCH_SIZE, shuffle=True
)

# Run everything on the CPU. The quantized engine selected later in this
# notebook (qnnpack) is a CPU backend, so the float model and the data are
# kept on the CPU as well. The former CUDA/MPS detection was dead code: its
# result was unconditionally overwritten with "cpu".
device = "cpu"

# Define the model

In [5]:
class VerySimpleNet(nn.Module):
    """Fully connected classifier with two hidden layers for 28x28 images.

    Args:
        hidden_size_1: Number of units in the first hidden layer.
        hidden_size_2: Number of units in the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        # Layers are created in the order forward() applies them
        self.linear1 = nn.Linear(in_features=28 * 28, out_features=hidden_size_1)
        self.linear2 = nn.Linear(in_features=hidden_size_1, out_features=hidden_size_2)
        self.linear3 = nn.Linear(in_features=hidden_size_2, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to rows of 784 values and return 10 raw scores per row."""
        flat = img.view(-1, 28 * 28)
        hidden = self.relu(self.linear1(flat))
        hidden = self.relu(self.linear2(hidden))
        return self.linear3(hidden)

In [6]:
# Place the float model on the same `device` the training and testing loops
# move their batches to, instead of hard-coding a second device string here
net = VerySimpleNet().to(device)

# Train the model

In [7]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` in place with Adam (lr=0.001) and cross-entropy loss.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        net: Model to optimise; it is switched to training mode every epoch.
        epochs: Number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the total number of batches
            processed across all epochs. Training returns as soon as the cap
            is reached. ``None`` means no cap.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    has_limit = total_iterations_limit is not None
    steps_done = 0

    for epoch in range(epochs):
        net.train()
        running_loss = 0

        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        if has_limit:
            # Make the progress bar count towards the cap rather than the loader length
            progress.total = total_iterations_limit

        for batch_number, (inputs, labels) in enumerate(progress, start=1):
            steps_done += 1
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = net(inputs.view(-1, 28 * 28))
            loss = criterion(logits, labels)

            # Show the running mean of the loss over this epoch's batches
            running_loss += loss.item()
            progress.set_postfix(loss=running_loss / batch_number)

            loss.backward()
            optimizer.step()

            if has_limit and steps_done >= total_iterations_limit:
                return


def print_size_of_model(model):
    """Print the on-disk size, in kilobytes, of ``model``'s serialized state dict.

    The state dict is saved to a scratch file in the current working
    directory, measured, and the file is removed again. The scratch file is
    removed even when saving, measuring or printing raises, so a failure does
    not leave a stray file behind.

    Args:
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
    """
    scratch_path = "temp_delme.p"
    try:
        torch.save(model.state_dict(), scratch_path)
        print("Size (KB):", os.path.getsize(scratch_path) / 1e3)
    finally:
        # The file may not exist if torch.save failed before creating it
        if os.path.exists(scratch_path):
            os.remove(scratch_path)


# Checkpoint of the trained float model; reused on later runs to skip training
MODEL_FILENAME = "./data/simplenet_ptq.pth"

if Path(MODEL_FILENAME).exists():
    # map_location places the loaded tensors on `device` regardless of the
    # device they were saved from
    net.load_state_dict(torch.load(MODEL_FILENAME, map_location=device))
    print("Loaded model from disk")
else:
    train(train_loader, net, epochs=1)
    # Save the model to disk; create the target directory explicitly rather
    # than relying on the dataset download having created it
    Path(MODEL_FILENAME).parent.mkdir(parents=True, exist_ok=True)
    torch.save(net.state_dict(), MODEL_FILENAME)

Epoch 1: 100%|██████████| 6000/6000 [00:13<00:00, 457.94it/s, loss=0.223]


# Define the testing loop

In [8]:
def test(model: nn.Module, total_iterations: int = None, data_loader=None):
    """Evaluate ``model`` and print its classification accuracy.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        total_iterations: Optional cap on the number of batches to evaluate.
            ``None`` evaluates every batch of the loader.
        data_loader: Iterable yielding ``(inputs, labels)`` batches. Defaults
            to the module-level ``test_loader``.

    Prints ``Accuracy: <value>`` rounded to three decimals, or a notice when
    no sample was evaluated (instead of raising ``ZeroDivisionError``).
    """
    if data_loader is None:
        data_loader = test_loader

    correct = 0
    total = 0

    model.eval()

    with torch.no_grad():
        batches = tqdm(data_loader, desc="Testing")
        for iterations, (x, y) in enumerate(batches, start=1):
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            # Compare predicted and true classes for the whole batch at once
            correct += (output.argmax(dim=1) == y).sum().item()
            total += output.size(0)
            if total_iterations is not None and iterations >= total_iterations:
                break

    if total == 0:
        print("Accuracy: n/a (no samples were evaluated)")
        return
    print(f"Accuracy: {round(correct/total, 3)}")

# Print weights and size of the model before quantization

In [9]:
# Show the first layer's weight matrix and its dtype as the pre-quantization baseline
print(
    "Weights before quantization",
    net.linear1.weight,
    net.linear1.weight.dtype,
    sep="\n",
)

Weights before quantization
Parameter containing:
tensor([[-0.0102,  0.0093, -0.0393,  ...,  0.0120, -0.0062, -0.0078],
        [-0.0199, -0.0151, -0.0106,  ..., -0.0204, -0.0061, -0.0301],
        [ 0.0103,  0.0453, -0.0028,  ...,  0.0101,  0.0316,  0.0385],
        ...,
        [ 0.0405,  0.0443,  0.0096,  ...,  0.0043,  0.0235, -0.0134],
        [ 0.0126,  0.0207,  0.0528,  ...,  0.0331,  0.0331,  0.0326],
        [ 0.0208,  0.0155,  0.0013,  ...,  0.0377, -0.0116,  0.0086]],
       requires_grad=True)
torch.float32


In [10]:
# Report the serialized size of the float model's state dict as a baseline
print("Size of the model before quantization")
print_size_of_model(net)

Size of the model before quantization
Size (KB): 360.559


In [11]:
# Baseline accuracy of the float model (plain string: no placeholders to format)
print("Accuracy of the model before quantization: ")
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:00<00:00, 1803.29it/s]

Accuracy: 0.961





# Insert min-max observers in the model

In [25]:
class QuantizedVerySimpleNet(nn.Module):
    """``VerySimpleNet`` layout wrapped in quantization stubs.

    The linear layers keep the names ``linear1``/``linear2``/``linear3`` so a
    state dict from the float model can be loaded directly. ``quant`` and
    ``dequant`` mark where tensors enter and leave the quantized region.

    Args:
        hidden_size_1: Number of units in the first hidden layer.
        hidden_size_2: Number of units in the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(QuantizedVerySimpleNet, self).__init__()
        # Use the torch.ao.quantization namespace, consistent with the
        # prepare()/convert() calls elsewhere in this notebook
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        """Flatten ``img`` to rows of 784 values and return 10 scores per row."""
        x = img.view(-1, 28 * 28)
        x = self.quant(x)  # entry point of the quantized region
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)  # exit point of the quantized region
        return x

In [31]:
# Select the qnnpack quantized engine explicitly to avoid the error
# "Didn't find engine for operation quantized::linear_prepack NoQEngine"
# seen with PyTorch on Apple Silicon.
print(torch.backends.quantized.supported_engines)
# Only switch when this build actually provides qnnpack; assigning an
# unsupported engine raises, so other builds keep their default engine.
if "qnnpack" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "qnnpack"

['qnnpack', 'none']


In [27]:
# Build the stub-wrapped twin of `net` and give it the trained float weights
net_quantized = QuantizedVerySimpleNet().to(device)
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

# Attach the default quantization config, then let prepare() insert the observers
default_qconfig = torch.ao.quantization.default_qconfig
print(default_qconfig)
net_quantized.qconfig = default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized)  # Insert observers
net_quantized

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Calibrate the model using the test set

In [28]:
# Calibration pass: running the evaluation loop pushes data through the
# observers inserted by prepare(), which record min/max activation values.
# NOTE(review): the same test set is used later to report accuracy —
# consider calibrating on training data instead.
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:00<00:00, 1606.48it/s]

Accuracy: 0.961





In [29]:
# Each observer now reports the min/max values it recorded
# (plain string: there are no placeholders to format)
print("Check statistics of the various layers")
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-52.8941764831543, max_val=37.16175079345703)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-31.748912811279297, max_val=25.99117660522461)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-26.45305824279785, max_val=22.420658111572266)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Quantize the model using the statistics collected

In [30]:
# Replace the observed float modules with their quantized counterparts.
# NOTE(review): rebinding `net_quantized` makes this cell non-idempotent —
# re-run the prepare and calibration cells before running it again.
net_quantized = torch.ao.quantization.convert(net_quantized)

In [32]:
# After convert() the layers display their quantization scale and zero point
# (plain string: there are no placeholders to format)
print("Check statistics of the various layers")
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.7091017365455627, zero_point=75, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.45464637875556946, zero_point=70, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.3848324418067932, zero_point=69, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

# Print weights of the model after quantization

In [41]:
# Fetch the first layer's quantized weight once, then show it both as a
# quantized tensor and as its underlying integer representation
quantized_weight = net_quantized.linear1.weight()
print(
    "Weights after quantization",
    quantized_weight,
    torch.int_repr(quantized_weight),
    sep="\n",
)

Weights after quantization
tensor([[-0.0099,  0.0099, -0.0394,  ...,  0.0099, -0.0049, -0.0099],
        [-0.0197, -0.0148, -0.0099,  ..., -0.0197, -0.0049, -0.0296],
        [ 0.0099,  0.0444, -0.0049,  ...,  0.0099,  0.0296,  0.0394],
        ...,
        [ 0.0394,  0.0444,  0.0099,  ...,  0.0049,  0.0246, -0.0148],
        [ 0.0148,  0.0197,  0.0542,  ...,  0.0345,  0.0345,  0.0345],
        [ 0.0197,  0.0148,  0.0000,  ...,  0.0394, -0.0099,  0.0099]],
       size=(100, 784), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.004928279668092728,
       zero_point=0)
tensor([[-2,  2, -8,  ...,  2, -1, -2],
        [-4, -3, -2,  ..., -4, -1, -6],
        [ 2,  9, -1,  ...,  2,  6,  8],
        ...,
        [ 8,  9,  2,  ...,  1,  5, -3],
        [ 3,  4, 11,  ...,  7,  7,  7],
        [ 4,  3,  0,  ...,  8, -2,  2]], dtype=torch.int8)


# Compare the dequantized weights and the original weights

In [42]:
# Print the float weights next to the dequantized ones so the rounding
# introduced by quantization can be compared entry by entry
print("Original weights: ")
print(net.linear1.weight)
print("")
print("Dequantized weights: ")  # plain string: no placeholders to format
print(torch.dequantize(net_quantized.linear1.weight()))
print("")

Original weights: 
Parameter containing:
tensor([[-0.0102,  0.0093, -0.0393,  ...,  0.0120, -0.0062, -0.0078],
        [-0.0199, -0.0151, -0.0106,  ..., -0.0204, -0.0061, -0.0301],
        [ 0.0103,  0.0453, -0.0028,  ...,  0.0101,  0.0316,  0.0385],
        ...,
        [ 0.0405,  0.0443,  0.0096,  ...,  0.0043,  0.0235, -0.0134],
        [ 0.0126,  0.0207,  0.0528,  ...,  0.0331,  0.0331,  0.0326],
        [ 0.0208,  0.0155,  0.0013,  ...,  0.0377, -0.0116,  0.0086]],
       requires_grad=True)

Dequantized weights: 
tensor([[-0.0099,  0.0099, -0.0394,  ...,  0.0099, -0.0049, -0.0099],
        [-0.0197, -0.0148, -0.0099,  ..., -0.0197, -0.0049, -0.0296],
        [ 0.0099,  0.0444, -0.0049,  ...,  0.0099,  0.0296,  0.0394],
        ...,
        [ 0.0394,  0.0444,  0.0099,  ...,  0.0049,  0.0246, -0.0148],
        [ 0.0148,  0.0197,  0.0542,  ...,  0.0345,  0.0345,  0.0345],
        [ 0.0197,  0.0148,  0.0000,  ...,  0.0394, -0.0099,  0.0099]])



# Print size and accuracy of the quantized model

In [43]:
# Report the serialized size of the quantized model's state dict
print("Size of the model after quantization")
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 94.827


In [44]:
# Re-run the evaluation loop on the converted (quantized) model
print("Testing the model after quantization")
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:00<00:00, 1408.96it/s]

Accuracy: 0.961



