Quantization

In [159]:
import os
import tempfile

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn

In [160]:

class SimpleNN(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear -> Sigmoid.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of the hidden layer.
        output_size: Number of output units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Sub-modules are registered in the order fc1, fc2, sigmoid, relu so
        # the printed module tree and the state_dict keys keep that order.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through both layers and return the sigmoid activations."""
        hidden = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))


In [161]:
# Layer dimensions used to build both the float model and its quantizable copy.
input_size = 8
hidden_size = 32
output_size = 16

In [162]:
# Floating-point baseline model; the bare expression below displays its layers.
model = SimpleNN(input_size, hidden_size, output_size)
model

SimpleNN(
  (fc1): Linear(in_features=8, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=16, bias=True)
  (sigmoid): Sigmoid()
  (relu): ReLU()
)

In [163]:
# Quantization types:
# 1. Symmetric quantization
# 2. Asymmetric quantization

def print_size_of_model(model):
    """Print the size, in KB, of ``model``'s serialized ``state_dict``.

    The checkpoint is written into a private temporary directory rather than
    the working directory, so an existing ``temp_delme.p`` there is never
    overwritten or deleted, and the scratch file is removed even if saving or
    measuring raises.

    Args:
        model: Any object exposing ``state_dict()`` (e.g. an ``nn.Module``).
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        # Same file name as before; only the directory it lives in changes.
        checkpoint_path = os.path.join(scratch_dir, "temp_delme.p")
        torch.save(model.state_dict(), checkpoint_path)
        print('Size (KB):', os.path.getsize(checkpoint_path)/1e3)
    
def clamp(x, min_val, max_val):
    """Return ``x`` with every element limited to ``[min_val, max_val]``.

    The result is a new tensor; the caller's ``x`` is left unmodified (masked
    assignment on ``x`` itself would silently overwrite the input).

    Args:
        x: Tensor to clamp.
        min_val: Lower bound; smaller elements are replaced by it.
        max_val: Upper bound; larger elements are replaced by it.

    Returns:
        A tensor with the same shape as ``x`` holding the clamped values.
    """
    return torch.clamp(x, min=min_val, max=max_val)

def asymmetric_quantization(x, num_bits):
    """Affine (asymmetric) min-max quantization onto a signed ``num_bits`` grid.

    The observed value range ``[beta, alpha]`` is mapped linearly onto the
    integer range ``[-2**(num_bits-1), 2**(num_bits-1) - 1]``:

        q = clamp(round(x / scale) + zero_point, qmin, qmax)

    Args:
        x: Floating-point tensor to quantize (must be non-empty).
        num_bits: Bit width of the target integer grid.

    Returns:
        Tuple ``(quantized, scale, zero_point)``. ``quantized`` has the shape
        of ``x`` and holds integer-valued floats; ``scale`` and ``zero_point``
        are 0-dim tensors to pass to ``asymmetric_dequantization``.
    """
    qmin = -2 ** (num_bits - 1)
    qmax = 2 ** (num_bits - 1) - 1

    # Use the real min/max of the data, widened to contain 0 so that zero is
    # always exactly representable and one-sided data keeps a valid range.
    alpha = torch.clamp(x.max(), min=0)
    beta = torch.clamp(x.min(), max=0)

    scale = (alpha - beta) / (qmax - qmin)
    if scale == 0:
        # All-zero input: any positive scale works; avoid dividing by zero.
        scale = torch.ones_like(scale)

    # Integer that the real value 0.0 maps to; it shifts beta onto qmin.
    zero_point = torch.round(qmin - beta / scale)
    quantized = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return quantized, scale, zero_point

def asymmetric_dequantization(x, scale, zero_point):
    """Invert ``asymmetric_quantization``: ``(x - zero_point) * scale``.

    Args:
        x: Quantized values produced by ``asymmetric_quantization``.
        scale: Step size returned alongside ``x``.
        zero_point: Integer offset returned alongside ``x``.

    Returns:
        Tensor of reconstructed real values.
    """
    # The offset is removed in the integer domain before rescaling.
    return (x - zero_point) * scale

def symmetric_quantization(x, num_bits):
    """Symmetric quantization: scale by the largest magnitude, zero point is 0.

    Args:
        x: Floating-point tensor to quantize (must be non-empty).
        num_bits: Bit width of the signed target grid; must be at least 2 so
            that the grid has a non-zero positive level.

    Returns:
        Tuple ``(quantized, scale)`` where ``quantized`` holds integer-valued
        floats in ``[-(2**(num_bits-1) - 1), 2**(num_bits-1) - 1]`` and
        ``scale`` is a 0-dim tensor.

    Raises:
        ValueError: If ``num_bits`` is smaller than 2.
    """
    if num_bits < 2:
        raise ValueError("num_bits must be >= 2 for symmetric quantization")

    qmax = 2 ** (num_bits - 1) - 1
    alpha = torch.max(torch.abs(x))
    scale = alpha / qmax
    if scale == 0:
        # All-zero input would otherwise produce 0/0 = NaN below.
        scale = torch.ones_like(scale)
    quantized = torch.round(x / scale)
    return quantized, scale

def symmetric_dequantization(x, scale):
    """Map symmetric-quantized values back to real values.

    Args:
        x: Quantized values.
        scale: Step size that was used to quantize ``x``.

    Returns:
        The product ``x * scale``.
    """
    restored = x * scale
    return restored

def quantize_percentile(x, num_bits, percentile=99.9):
    """Symmetric quantization whose range is set by a percentile of ``|x|``.

    Using a percentile of the magnitudes instead of the maximum keeps a few
    outliers from stretching the scale; values beyond the chosen percentile
    are clipped to the edge of the integer grid.

    Args:
        x: Floating-point tensor to quantize. May require grad.
        num_bits: Bit width of the signed target grid; must be at least 2.
        percentile: Percentile (0-100) of ``|x|`` mapped to the largest level.

    Returns:
        Tuple ``(quantized, scale)`` where ``quantized`` holds integer-valued
        floats in ``[-(2**(num_bits-1) - 1), 2**(num_bits-1) - 1]`` and
        ``scale`` is a Python float.

    Raises:
        ValueError: If ``num_bits`` is smaller than 2.
    """
    if num_bits < 2:
        raise ValueError("num_bits must be >= 2 for percentile quantization")

    qmax = 2 ** (num_bits - 1) - 1
    # Percentile of magnitudes (not signed values) so the scale is never
    # negative; detach/cpu so tensors that require grad can be converted.
    magnitudes = x.detach().abs().cpu().numpy()
    alpha = float(np.percentile(magnitudes, percentile))
    # Degenerate range (e.g. all zeros): fall back to a unit step.
    scale = alpha / qmax if alpha > 0 else 1.0
    # Clip the outliers that lie beyond the percentile onto the grid edge.
    quantized = torch.clamp(torch.round(x / scale), -qmax, qmax)
    return quantized, scale

def dequantize_percentile(x, scale):
    """Map percentile-quantized values back to real values.

    Args:
        x: Quantized values.
        scale: Step size that was used to quantize ``x``.

    Returns:
        The product ``x * scale``.
    """
    restored = x * scale
    return restored

In [164]:
# Quantized NN
class QuantizedSimpleNN(nn.Module):
    """Quantization-ready variant of the two-layer network.

    The layers are Linear -> ReLU -> Linear -> Sigmoid, wrapped between a
    ``QuantStub`` on the input and a ``DeQuantStub`` on the output. The linear
    layers use the attribute names ``fc1`` and ``fc2``.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of the hidden layer.
        output_size: Number of output units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Stubs come from torch.ao.quantization, the same namespace the rest of
        # the notebook uses for qconfig / prepare / convert.
        # The attribute name "qunat" (sic) is kept: it is part of the module's
        # public structure and shows up in the printed module tree.
        self.qunat = torch.ao.quantization.QuantStub()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        """Apply quant stub, both layers with activations, then dequant stub."""
        x = self.qunat(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.dequant(x)
        return x


In [165]:
# Build the quantizable copy and load the float model's weights into it.
qunat_model = QuantizedSimpleNN(input_size, hidden_size, output_size)
qunat_model.load_state_dict(model.state_dict())
# Put the copy in eval mode before observers are inserted.
qunat_model.eval()

# Set the quantization backend (required for conversion).
# Use 'qnnpack' for ARM CPUs (like Apple Silicon) or 'fbgemm' for x86 CPUs.
# The backend name appears twice below (engine and qconfig); change both together.
torch.backends.quantized.engine = 'qnnpack'  # or 'fbgemm' for x86
qunat_model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')  # or 'fbgemm'
qunat_model = torch.ao.quantization.prepare(qunat_model) # Insert observers

# Display the float model next to the prepared (observer-instrumented) model.
model, qunat_model

(SimpleNN(
   (fc1): Linear(in_features=8, out_features=32, bias=True)
   (fc2): Linear(in_features=32, out_features=16, bias=True)
   (sigmoid): Sigmoid()
   (relu): ReLU()
 ),
 QuantizedSimpleNN(
   (qunat): QuantStub(
     (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
   )
   (fc1): Linear(
     in_features=8, out_features=32, bias=True
     (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
   )
   (fc2): Linear(
     in_features=32, out_features=16, bias=True
     (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
   )
   (sigmoid): Sigmoid()
   (relu): ReLU()
   (dequant): DeQuantStub()
 ))

In [166]:
# Calibration step: run a batch through the prepared model so its observers
# record the value ranges used to estimate the quantization parameters.
# NOTE(review): the batch is unseeded random noise, so the recorded min/max
# differ between runs — consider seeding or using representative data.
dummy_input = torch.randn(10, input_size)
qunat_model(dummy_input)
# Display the model again to show the observer statistics after calibration.
qunat_model


QuantizedSimpleNN(
  (qunat): QuantStub(
    (activation_post_process): HistogramObserver(min_val=-2.006115674972534, max_val=2.4371373653411865)
  )
  (fc1): Linear(
    in_features=8, out_features=32, bias=True
    (activation_post_process): HistogramObserver(min_val=-1.6454459428787231, max_val=1.8079413175582886)
  )
  (fc2): Linear(
    in_features=32, out_features=16, bias=True
    (activation_post_process): HistogramObserver(min_val=-0.6881318092346191, max_val=0.9596636891365051)
  )
  (sigmoid): Sigmoid()
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [167]:
# Convert the calibrated model to its quantized form, print the resulting
# module tree, then report the serialized size of the float model followed by
# the quantized one.
qunat_model = torch.ao.quantization.convert(qunat_model)
print(qunat_model)
print_size_of_model(model)
print_size_of_model(qunat_model)

QuantizedSimpleNN(
  (qunat): Quantize(scale=tensor([0.0174]), zero_point=tensor([115]), dtype=torch.quint8)
  (fc1): QuantizedLinear(in_features=8, out_features=32, scale=0.013536081649363041, zero_point=122, qscheme=torch.per_tensor_affine)
  (fc2): QuantizedLinear(in_features=32, out_features=16, scale=0.006458787713199854, zero_point=107, qscheme=torch.per_tensor_affine)
  (sigmoid): Sigmoid()
  (relu): ReLU()
  (dequant): DeQuantize()
)
Size (KB): 5.416
Size (KB): 4.962


In [168]:
# Inspect the raw integer values stored for fc1's quantized weight tensor.
torch.int_repr(qunat_model.fc1.weight())

tensor([[  97,    7,  115,  -96,  -45,  113, -122, -119],
        [  21,  -88,  -36,    6,  -58,    0,   46,  -98],
        [ 117,   55,   22, -105, -117,  -85,  -14,  -10],
        [ -71,   73,   95,   80,    1,   86,  -86,   90],
        [  56,  -49,  -31,   97, -105,  -84,  108,  -96],
        [  12,   34,   -5,   17,  -87,   29, -128,   -9],
        [  85,    5, -119,   40,   51,  -58,  -15,  -55],
        [   0,   56,  -26,   15,  -92,   61,  -86,  -91],
        [  -4,   73,   96,   40,   31, -113,  -44,  106],
        [  38,  -71,  101,  -61,   89,  126, -102, -108],
        [ -36,    2, -121,  101,  -68, -112,  -82,  -28],
        [  41,   62,   13,  -82,   -5,  -17,  -55,   24],
        [  25,  -23,  -12,    1,   76,  -15,  -18,  -47],
        [ -99,  -98,  -61,   41, -118,  -87,    6,   44],
        [ -67,  -14,  -40,  -69,  105,  -23,  -41,    3],
        [ -18,   15,  -66, -107, -113,   26,    8,  -60],
        [ -64,   54,   39,   62,  -90,   49,  112,  -54],
        [ -37,