In [3]:
import os
import torch
import torch.nn as nn
import torchvision.datasets as datasets
from torchvision.transforms import transforms
from tqdm import tqdm
from pathlib import Path


# Load MNIST dataset

In [4]:
# For reproducibility
_=torch.manual_seed(0)
# Converting image into Pytoch Tensor
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])

# Load MNIST
mnist_trainset=datasets.MNIST(root='/data',train=True,download=True,transform=transform)
# Create dataloader
train_loader=torch.utils.data.DataLoader(mnist_trainset,batch_size=10,shuffle=True)

# Load MNIST test set
mnist_testset=datasets.MNIST(root='/data',train=False,download=True,transform=transform)
#Create Dataloader
test_loader=torch.utils.data.DataLoader(mnist_testset,batch_size=10,shuffle=True)

device='cuda'

# Define the Model

In [5]:
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

net = VerySimpleNet().to(device)

# Insert min-max observe in the model

In [6]:
# This adds fake quant nodes around layers during training.
# So it behaves better after quantization
net.qconfig = torch.ao.quantization.default_qconfig
net.train()
net_quantized = torch.ao.quantization.prepare_qat(net) # Insert observers
net_quantized

VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Train the model

In [7]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

train(train_loader, net_quantized, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:35<00:00, 166.73it/s, loss=0.223]


#Define testing Loop

In [8]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

In [9]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5414018034934998, max_val=0.31874144077301025)
    (activation_post_process): MinMaxObserver(min_val=-39.6721305847168, max_val=33.988555908203125)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.41715824604034424, max_val=0.3596401810646057)
    (activation_post_process): MinMaxObserver(min_val=-38.520084381103516, max_val=21.97124671936035)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.44865426421165466, max_val=0.21375349164009094)
    (activation_post_process): MinMaxObserver(min_val=-31.449342727661133, max_val=22.950424194335938)
  )
  (relu): ReLU()
  (dequant): DeQua

In [10]:
net_quantized.eval()
net_quantized = torch.ao.quantization.convert(net_quantized)

In [11]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


VerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256], device='cuda:0'), zero_point=tensor([17], device='cuda:0'), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.580005407333374, zero_point=68, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.47630971670150757, zero_point=81, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.42834460735321045, zero_point=73, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [12]:
# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights before quantization
tensor([[ 3,  7, -4,  ...,  8,  4,  3],
        [-5, -4, -3,  ..., -5, -2, -8],
        [ 2, 10, -1,  ...,  2,  7,  9],
        ...,
        [10, 11,  3,  ...,  2,  6, -3],
        [-2,  0,  8,  ...,  3,  3,  3],
        [ 6,  4,  1,  ..., 10, -2,  3]], device='cuda:0', dtype=torch.int8)


In [None]:
print('Testing the model after quantization')
test(net_quantized)