In [219]:
# Imports 
import torch
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
import os
torch.backends.quantized.engine = 'qnnpack'
torch.set_printoptions(sci_mode = False)

## Basic Quantization

In [220]:
# Generating Random Params
params = torch.tensor(np.random.uniform(low = -10, 
                                        high = 20, 
                                        size = 200000))


In [221]:
def asymmetric_quantization(params: torch.tensor, n_bits: int, percentile: float = 0) -> [torch.tensor]:
    
    # quantizations init
    
    if not percentile:
        alpha = torch.max(params)
        beta = torch.min(params)
    
    else:
        alpha = torch.quantile(params, percentile, interpolation='linear')
        beta = torch.quantile(params, 1 - percentile, interpolation='linear') 
        
    scale = (alpha - beta) / (2 ** n_bits - 1)
    z = -1 * torch.round((beta / scale))


    # quantization
    q_params = (torch.round(params / scale) + z)
    
    # value clipping
    q_params_clamped = torch.clamp(q_params, min = 0, max = 2 ** n_bits - 1)
    
    # dequantization
    dq_params = (q_params_clamped - z) * scale
    
    # MSE error in quantization
    error = torch.mean((params - dq_params) ** 2)
    
    return dq_params, error
    
dq_params, mse = asymmetric_quantization(params, 8, 0)
dq_params_percentile, mse_percentile = asymmetric_quantization(params, 8, 0.9999)

In [222]:
print(f"MSE Error in min-max qunatization {mse.item()}")
print(f"MSE Error in percentile(0.99) quantization {mse_percentile.item()}")

MSE Error in min-max qunatization 0.001153559216033612
MSE Error in percentile(0.99) quantization 0.0011527203556808393


## Quantization in Neural Networks

### Quantization Aware Training 

In [223]:


torch.manual_seed(0)

transform = transforms.Compose([transforms.ToTensor()])

mnist_trainset = datasets.MNIST(root = "./data", train = True, download = True, transform = transform)

mnist_testset = datasets.MNIST(root = "./data", train = False, download = True, transform = transform)


train_dataloader = torch.utils.data.DataLoader(mnist_trainset, batch_size = 10, shuffle = True)
test_dataloader = torch.utils.data.DataLoader(mnist_testset, batch_size = 10, shuffle = False)


device = 'cpu'

In [224]:
# Qunatization Aware Training
class TinyNet(nn.Module):
    
    def __init__(self, hidden1_size = 100, hidden2_size = 100):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28 * 28, hidden1_size)
        self.linear2 = nn.Linear(hidden1_size, hidden2_size)
        self.linear3 = nn.Linear(hidden2_size, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.relu = nn.ReLU()
    
    def forward(self, img):
        x = img.view(-1, 28 * 28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

    
quantization_aware_model = TinyNet().to(device)

In [225]:
quantization_aware_model.qconfig = torch.ao.quantization.default_qat_qconfig

quantization_aware_model.train()

quantized_aware_model = torch.ao.quantization.prepare_qat(quantization_aware_model)

quantized_aware_model

TinyNet(
  (quant): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=127, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.ui

In [226]:
def train(train_loader, network, epochs=5):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        network.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = network(x.view(-1, 28*28))
            loss = ce_loss(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

train(train_dataloader, quantized_aware_model, epochs=1)

Epoch 1: 100%|█████████████████| 6000/6000 [00:14<00:00, 409.09it/s, loss=0.257]


In [227]:
quantized_aware_model

TinyNet(
  (quant): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=127, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0079]), zero_point=tensor([0], dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.0)
    )
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.0039]), zero_point=tensor([0], dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=-0.49763086438179016, max_val=0.4351730942726135)
    )
    (activation_post_process): FakeQuantize(
      fake_qu

In [232]:
def eval(model):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_dataloader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
    print(f'Accuracy: {round(correct/total, 3)}')

In [230]:
quantized_aware_model.eval()

quantized_aware_model_INT = torch.ao.quantization.convert(quantized_aware_model)

In [257]:
eval(quantized_aware_model)

Testing: 100%|████████████████████████████| 1000/1000 [00:00<00:00, 1003.62it/s]

Accuracy: 0.954





In [258]:
eval(quantized_aware_model_INT)

Testing: 100%|████████████████████████████| 1000/1000 [00:00<00:00, 2560.14it/s]

Accuracy: 0.954





In [253]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    size = os.path.getsize("temp_delme.p")/1e3
    os.remove('temp_delme.p')
    return size

In [256]:
NQ_model_size = print_size_of_model(quantization_aware_model)
Q_model_size = print_size_of_model(quantized_aware_model_INT)
print(f"Non Quantized Model Size (KB):{print_size_of_model(quantization_aware_model)} ")
print(f"Quantized Model Size (KB):{print_size_of_model(quantized_aware_model_INT)} ")
print(f"Compression Rate: {Q_model_size / NQ_model_size}")

Non Quantized Model Size (KB):361.062 
Quantized Model Size (KB):95.266 
Compression Rate: 0.26384942198292816


### Post Training Static Quantization

In [271]:
# Post Training Static Quantization Model
class TinyNet(nn.Module):
    
    def __init__(self, hidden1_size = 100, hidden2_size = 100):
        super().__init__()
#         self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28 * 28, hidden1_size)
        self.linear2 = nn.Linear(hidden1_size, hidden2_size)
        self.linear3 = nn.Linear(hidden2_size, 10)
#         self.dequant = torch.ao.quantization.DeQuantStub()
        self.relu = nn.ReLU()
    
    def forward(self, img):
        x = img.view(-1, 28 * 28)
#         x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
#         x = self.dequant(x)
        return x

In [272]:
Non_Quant_model = TinyNet()

In [273]:
train(train_dataloader, Non_Quant_model, epochs = 1)

Epoch 1: 100%|██████████████████| 6000/6000 [00:09<00:00, 666.14it/s, loss=0.24]


In [274]:
eval(Non_Quant_model)

Testing: 100%|████████████████████████████| 1000/1000 [00:00<00:00, 3471.12it/s]

Accuracy: 0.965





In [275]:
class QuantizedTinyNet(nn.Module):
    
    def __init__(self, hidden1_size = 100, hidden2_size = 100):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28 * 28, hidden1_size)
        self.linear2 = nn.Linear(hidden1_size, hidden2_size)
        self.linear3 = nn.Linear(hidden2_size, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.relu = nn.ReLU()
    
    def forward(self, img):
        x = img.view(-1, 28 * 28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

In [277]:
Quant_model = QuantizedTinyNet().to(device)

Quant_model.load_state_dict(Non_Quant_model.state_dict())

Quant_model.eval()

Quant_model.qconfig = torch.ao.quantization.get_default_qconfig()

PTSQ_model = torch.ao.quantization.prepare(Quant_model)

PTSQ_model

QuantizedTinyNet(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (dequant): DeQuantStub()
  (relu): ReLU()
)

In [278]:
eval(PTSQ_model)

Testing: 100%|████████████████████████████| 1000/1000 [00:00<00:00, 1241.36it/s]

Accuracy: 0.965





In [281]:
PTSQ_model_INT = torch.ao.quantization.convert(PTSQ_model)

In [282]:
eval(PTSQ_model_INT)

Testing: 100%|████████████████████████████| 1000/1000 [00:00<00:00, 2527.05it/s]

Accuracy: 0.966



