In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
import os

# Load MNIST Dataset

In [2]:
# Make torch deterministic: seed torch's global RNG before any weights are created.
# The result is bound to `_` so the returned Generator is not echoed as cell output.
_ = torch.manual_seed(0)

In [3]:
# Everything in this notebook runs on the CPU
device = "cpu"

# Preprocessing shared by both splits: convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split (downloaded into ./data when missing) and its shuffled loader
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    mnist_trainset,
    batch_size=10,
    shuffle=True,
)

# Test split and its loader (same batch size, also shuffled)
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    mnist_testset,
    batch_size=10,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 487kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 3.79MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.59MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






# Define Model

In [7]:
class SimpleNet(nn.Module):
    """Fully connected classifier: 784 inputs -> two hidden ReLU layers -> 10 scores.

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        # Layers are created in forward order.
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten `img` to rows of 784 values and return the raw output of the last layer."""
        flat = img.view(-1, 28 * 28)
        hidden = self.relu(self.linear1(flat))
        hidden = self.relu(self.linear2(hidden))
        return self.linear3(hidden)

In [8]:
# Instantiate the float model with its default hidden sizes and move it to `device`
net = SimpleNet().to(device)

# Train The Model

In [12]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train `net` in place with Adam (lr=0.001) and cross-entropy loss.

    Args:
        train_loader: iterable yielding (inputs, labels) batches.
        net: the module to optimise; it is switched to train mode every epoch.
        epochs: number of passes over `train_loader`.
        total_iterations_limit: when not None, training stops as soon as this
            many batches (counted across all epochs) have been processed.

    Returns:
        None. The progress bar shows the running mean loss of the current epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    has_limit = total_iterations_limit is not None
    steps_done = 0

    for epoch_idx in range(epochs):
        net.train()
        running_loss = 0

        progress = tqdm(train_loader, desc=f'Epoch {epoch_idx+1}')
        if has_limit:
            # Make the bar's length reflect the iteration cap
            progress.total = total_iterations_limit

        for batches_seen, (inputs, labels) in enumerate(progress, start=1):
            steps_done += 1
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = net(inputs.view(-1, 28*28))
            batch_loss = criterion(logits, labels)

            # Report the mean loss over the batches seen so far in this epoch
            running_loss += batch_loss.item()
            progress.set_postfix(loss=running_loss / batches_seen)

            batch_loss.backward()
            optimizer.step()

            if has_limit and steps_done >= total_iterations_limit:
                return

def print_size_of_model(model):
    """Print the on-disk size, in KB, of `model`'s serialized state_dict.

    The state_dict is written to a scratch file in the current working
    directory, measured, and the file is removed again. Removal happens in a
    `finally` block so the scratch file is not left behind when saving,
    measuring or printing raises.

    Args:
        model: any object exposing `state_dict()`.
    """
    tmp_path = "temp_delme.p"
    try:
        torch.save(model.state_dict(), tmp_path)
        print('size (KB):', os.path.getsize(tmp_path)/1e3)
    finally:
        # The file may not exist if torch.save failed before creating it
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# File in which the trained float weights are cached between runs
MODEL_FILENAME = 'SimpleNet_ptq.pt'

if Path(MODEL_FILENAME).exists():
  # map_location remaps the stored tensors onto `device`, so a checkpoint that
  # was written on a different device still loads here.
  # NOTE(review): torch.load unpickles the file - only load checkpoints you
  # trust (or pass weights_only=True on PyTorch versions that support it).
  net.load_state_dict(torch.load(MODEL_FILENAME, map_location=device))
  print('Loaded Model from Disk')
else:
  # No cached weights: train for a single epoch, then cache the result
  train(train_loader, net, epochs=1)
  # save the model to disk
  torch.save(net.state_dict(), MODEL_FILENAME)


Epoch 1: 100%|██████████| 6000/6000 [00:43<00:00, 136.96it/s, loss=0.219]


# Define Testing Loop

In [13]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

# Print weights and size of the model before quantization

In [14]:
# Show the first layer's weight matrix, and its dtype, before quantization
first_layer_weight = net.linear1.weight
print('Weights before quantization')
print(first_layer_weight)
print(first_layer_weight.dtype)

Weights before quantization
Parameter containing:
tensor([[ 0.0196,  0.0285,  0.0351,  ...,  0.0549,  0.0240,  0.0750],
        [ 0.0041, -0.0358, -0.0440,  ..., -0.0175, -0.0217,  0.0185],
        [ 0.0290,  0.0413,  0.0226,  ...,  0.0041,  0.0275,  0.0188],
        ...,
        [ 0.0271, -0.0004,  0.0332,  ...,  0.0025,  0.0248, -0.0060],
        [-0.0133,  0.0158,  0.0453,  ..., -0.0133,  0.0136,  0.0380],
        [ 0.0194, -0.0109, -0.0232,  ...,  0.0139, -0.0151,  0.0204]],
       requires_grad=True)
torch.float32


In [15]:
# Report the size of the float model's serialized state_dict
print('Size of the model before quantization')
print_size_of_model(net)

Size of the model before quantization
size (KB): 360.998


In [16]:
# Baseline: accuracy of the float model on the test loader
print('Accuracy of the model before quantization: ')
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:02<00:00, 400.30it/s]

Accuracy: 0.961





# Insert min-max observers in the model

In [17]:
class QuantizedSimpleNet(nn.Module):
    """Same layer stack as SimpleNet, wrapped between a QuantStub and a DeQuantStub.

    The linear layers use the same attribute names as SimpleNet (linear1,
    linear2, linear3), so a SimpleNet state_dict can be loaded directly.

    The stubs come from `torch.ao.quantization`, the namespace the rest of this
    notebook uses for `default_qconfig`, `prepare` and `convert`.

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        # Marks the point where the input enters the quantized region
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        # Marks the point where the output leaves the quantized region
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        """Flatten `img` to rows of 784 values and run quant -> layers -> dequant."""
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

In [18]:
# Build the stub-wrapped network on `device`, already in eval mode
net_quantized = QuantizedSimpleNet().to(device).eval()

# Give it the weights of the trained float model
net_quantized.load_state_dict(net.state_dict())

# Attach the default qconfig, then let prepare() insert the observers
net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized)

# Display the prepared module tree
net_quantized

QuantizedSimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [19]:
# Run the evaluation loop on the prepared model. These forward passes are what
# feed the observers: their min_val/max_val move from inf/-inf to real values.
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 369.71it/s]

Accuracy: 0.961





In [20]:
# The observers now hold the min/max values recorded during the evaluation pass
print('Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedSimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-45.13227081298828, max_val=35.510284423828125)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-25.628477096557617, max_val=27.260677337646484)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-31.464719772338867, max_val=23.279218673706055)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Quantize the model using the statistics collected

In [21]:
# Convert the observed model: the Linear layers become QuantizedLinear modules
# and the stubs become Quantize / DeQuantize (see the repr printed below)
net_quantized = torch.ao.quantization.convert(net_quantized)

In [22]:
# After conversion every layer reports the scale and zero_point it uses
print('Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedSimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6349807381629944, zero_point=71, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.41645002365112305, zero_point=62, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.43105462193489075, zero_point=73, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

# Print weights of the model after quantization

In [23]:
# Show the integer representation of the first layer's quantized weights
quantized_weight = net_quantized.linear1.weight()
print('Weights after quantization')
print(torch.int_repr(quantized_weight))

Weights after quantization
tensor([[  4,   6,   8,  ...,  12,   5,  16],
        [  1,  -8, -10,  ...,  -4,  -5,   4],
        [  6,   9,   5,  ...,   1,   6,   4],
        ...,
        [  6,   0,   7,  ...,   1,   5,  -1],
        [ -3,   3,  10,  ...,  -3,   3,   8],
        [  4,  -2,  -5,  ...,   3,  -3,   4]], dtype=torch.int8)


# Compare the dequantized weights and the original weights

In [24]:
# Fetch both versions of the first layer's weights, then print them one after the other
original_weight = net.linear1.weight
dequantized_weight = torch.dequantize(net_quantized.linear1.weight())

print('Original weights: ')
print(original_weight)
print('')
print('Dequantized weights: ')
print(dequantized_weight)
print('')

Original weights: 
Parameter containing:
tensor([[ 0.0196,  0.0285,  0.0351,  ...,  0.0549,  0.0240,  0.0750],
        [ 0.0041, -0.0358, -0.0440,  ..., -0.0175, -0.0217,  0.0185],
        [ 0.0290,  0.0413,  0.0226,  ...,  0.0041,  0.0275,  0.0188],
        ...,
        [ 0.0271, -0.0004,  0.0332,  ...,  0.0025,  0.0248, -0.0060],
        [-0.0133,  0.0158,  0.0453,  ..., -0.0133,  0.0136,  0.0380],
        [ 0.0194, -0.0109, -0.0232,  ...,  0.0139, -0.0151,  0.0204]],
       requires_grad=True)

Dequantized weights: 
tensor([[ 0.0183,  0.0274,  0.0366,  ...,  0.0549,  0.0229,  0.0731],
        [ 0.0046, -0.0366, -0.0457,  ..., -0.0183, -0.0229,  0.0183],
        [ 0.0274,  0.0411,  0.0229,  ...,  0.0046,  0.0274,  0.0183],
        ...,
        [ 0.0274,  0.0000,  0.0320,  ...,  0.0046,  0.0229, -0.0046],
        [-0.0137,  0.0137,  0.0457,  ..., -0.0137,  0.0137,  0.0366],
        [ 0.0183, -0.0091, -0.0229,  ...,  0.0137, -0.0137,  0.0183]])



# Print size and accuracy of the quantized model

In [25]:
# Report the size of the converted model's serialized state_dict
print('Size of the model after quantization')
print_size_of_model(net_quantized)

Size of the model after quantization
size (KB): 95.394


In [27]:
# Accuracy of the converted (quantized) model on the test loader.
# The message previously said "before quantization" although the quantized
# model is the one being evaluated here.
print('Testing the model after quantization')
test(net_quantized)

Testing the model before quantization


Testing: 100%|██████████| 1000/1000 [00:02<00:00, 407.54it/s]

Accuracy: 0.961



