In [2]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

In [3]:
_ = torch.manual_seed(0)

In [4]:
transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

mnist_testset = datasets.MNIST(root='./data', train=False, transform=transforms)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=False)

device = 'cpu'

In [5]:
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()
        
    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

In [6]:
net = VerySimpleNet().to(device)

In [7]:
import torch.ao.quantization


net.qconfig = torch.ao.quantization.default_qconfig
net.train()
net_quantized = torch.ao.quantization.prepare_qat(net)
net_quantized

VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [8]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    total_iterations = 0
    
    for epoch in range(epochs):
        net.train()
        
        loss_sum = 0
        num_iterations = 0
        
        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss) 
            loss.backward()
            optimizer.step()
            
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
        
def print_size_of_model(model):
    torch.save(model.state_dict(), 'temp_delme.p')
    print('Size (KB)', os.path.getsize('temp_delme.p')/1e3)

train(train_loader, net_quantized, epochs=1)
        

Epoch 1: 100%|██████████| 6000/6000 [00:35<00:00, 169.81it/s, loss=0.223]


In [9]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0
    
    iterations = 0
    
    model.eval()
    
    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 28*28))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                total += 1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f"Accuracy: {correct / total, 3}")

In [10]:
net_quantized.eval()
net_quantized = torch.ao.quantization.convert(net_quantized)

In [11]:
print("Statistcs of the various layers")
net_quantized

Statistcs of the various layers


VerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256], device='cuda:0'), zero_point=tensor([17], device='cuda:0'), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.5800052881240845, zero_point=68, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.4763098359107971, zero_point=81, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.42834457755088806, zero_point=73, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [12]:
print("Weights before quantization")
print(torch.int_repr(net_quantized.linear1.weight()))

Weights before quantization
tensor([[ 3,  7, -4,  ...,  8,  4,  3],
        [-5, -4, -3,  ..., -5, -2, -8],
        [ 2, 10, -1,  ...,  2,  7,  9],
        ...,
        [10, 11,  3,  ...,  2,  6, -3],
        [-2,  0,  8,  ...,  3,  3,  3],
        [ 6,  4,  1,  ..., 10, -2,  3]], device='cuda:0', dtype=torch.int8)


In [13]:
print("Testing the model after quantization")
test(net_quantized)

Testing the model after quantization


Testing:   0%|          | 0/1000 [00:00<?, ?it/s]


RuntimeError: Unable to find an engine to execute this computation Quantized Linear Cudnn