***Reference:*** https://github.com/pytorch/pytorch/wiki/torch_quantization_design_proposal

# ***Step-1: Setup & Import Libraries***

In [None]:
import torch
import torch.quantization
import torch.nn as nn
import copy
import os
import time

# ***Step-2: Float Tensor Representation from Netron***


In [None]:
# Sample float values; replace them with weights read from any trained
# model (for example, as displayed by the Netron viewer).
x = torch.tensor(
    [
        [0.8237, 0.5781, 0.6879],
        [0.3816, 0.7249, 0.0998],
    ]
)

# Serialize the float tensor and report the size of the file on disk, so it
# can be compared with the quantized tensor's file later on.
torch.save(x, 'float_tensor.pt')
print('Float tensor (bytes)', os.path.getsize('float_tensor.pt'))


Float tensor (bytes) 747


***Maximum and minimum values of x***

In [None]:
# Largest element of x: the upper end of the observed value range.
b = x.max()
print(b)

tensor(0.8237)


In [None]:
# Smallest element of x: the lower end of the observed value range.
a = x.min()
print(a)

tensor(0.0998)


In [None]:
# Scale: the width of one quantization step when the observed range
# [a, b] is divided into 255 equal intervals (256 levels of an 8-bit type).
scale = (b - a).div(255)
print(scale)

tensor(0.0028)


In [None]:
# Zero point: -a / scale with scale = (b - a) / 255, rounded to the
# nearest integer.
# NOTE(review): for the sample tensor the recorded result is -35, which lies
# outside the 0..255 storage range of torch.quint8; a value like this would
# presumably be rejected by torch.quantize_per_tensor. The quantize call in
# Step-3 passes literal values instead of this variable. Confirm the intended
# range handling before wiring it in.
zero_point = (-a * 255 / (b - a)).round()
print(zero_point)

tensor(-35.)


# ***Step-3: Apply Quantization per tensor (Affine mapping)***

***I. QScheme (torch.qscheme):*** an enum that specifies how the Tensor is quantized:

- ***(a)*** torch.per_tensor_affine
- ***(b)*** torch.per_tensor_symmetric
- ***(c)*** torch.per_channel_affine
- ***(d)*** torch.per_channel_symmetric

***Reference:*** https://pytorch.org/docs/stable/quantization.html#:~:text=PyTorch%20supports%20both%20per%20tensor,with%20the%20same%20quantization%20parameters.

In [None]:
# Data types noted for quantized tensors
# (dtype: the torch.dtype of the quantized Tensor):
#   torch.quint8, torch.qint8, torch.qint32, torch.float16
# NOTE(review): torch.float16 is not a quantized dtype; confirm against the
# torch.quantize_per_tensor documentation before relying on this list.
#
# NOTE(review): scale and zero_point are literals here, not the `scale` and
# `zero_point` variables computed in Step-2. With scale=0.001 the recorded
# output shows every element except the smallest one clipped at the maximum
# stored value 255. Confirm whether this saturation is intentional.
xq = torch.quantize_per_tensor(
    x,
    scale=0.001,
    zero_point=4,
    dtype=torch.quint8,
)
print(xq)

# Serialize the quantized tensor and report the size of the file on disk.
torch.save(xq, 'qtz_tensor.pt')
print('Quantized tensor (bytes)', os.path.getsize('qtz_tensor.pt'))

# Show the raw stored integers and the tensor's quantized dtype.
print(xq.int_repr())
print(xq.dtype)

tensor([[0.2510, 0.2510, 0.2510],
        [0.2510, 0.2510, 0.1000]], size=(2, 3), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.001, zero_point=4)
Quantized tensor (bytes) 811
tensor([[255, 255, 255],
        [255, 255, 104]], dtype=torch.uint8)
torch.quint8


# ***Step-4: Apply Dequantization***

In [None]:
# Convert the quantized tensor back to a regular floating-point tensor.
# Each value becomes (stored_int - zero_point) * scale, so anything that was
# clipped during quantization stays clipped here.
xdq = torch.dequantize(xq)
print(xdq)
print(xdq.dtype)

tensor([[0.2510, 0.2510, 0.2510],
        [0.2510, 0.2510, 0.1000]])
torch.float32


# ***Step-5: MAE/MSE loss between x and xdq***

***I. MAE loss***

In [None]:
# Re-import the libraries this cell uses so it can be run on its own.
import torch
import torch.nn as nn

# Criterion for the mean absolute error between two tensors.
mae = nn.L1Loss()

# Show both operands: the original floats and their dequantized counterpart.
print("Input Tensor:\n", x)
print("Target Tensor:\n", xdq)

# Error introduced by the quantize -> dequantize round trip.
# The loss is only reported; no backward pass is run.
output = mae(x, xdq)
print("MAE loss:", output)

Input Tensor:
 tensor([[0.8237, 0.5781, 0.6879],
        [0.3816, 0.7249, 0.0998]])
Target Tensor:
 tensor([[0.2510, 0.2510, 0.2510],
        [0.2510, 0.2510, 0.1000]])
MAE loss: tensor(0.3236)


***II. MSE Loss***

In [None]:
# Re-import the libraries this cell uses so it can be run on its own.
import torch
import torch.nn as nn

# Criterion for the mean squared error between two tensors.
mse = nn.MSELoss()

# Show both operands: the original floats and their dequantized counterpart.
print("Input Tensor:\n", x)
print("Target Tensor:\n", xdq)

# Squared error introduced by the quantize -> dequantize round trip.
# The loss is only reported; no backward pass is run.
output = mse(x, xdq)
print("MSE loss:", output)

Input Tensor:
 tensor([[0.8237, 0.5781, 0.6879],
        [0.3816, 0.7249, 0.0998]])
Target Tensor:
 tensor([[0.2510, 0.2510, 0.2510],
        [0.2510, 0.2510, 0.1000]])
MSE loss: tensor(0.1446)


In [None]:
# Build a tensor proto from `X` (presumably TensorFlow's tf.make_tensor_proto — confirm).
# NOTE(review): neither `tf` nor `X` is defined in the visible notebook cells;
# TensorFlow is never imported and only a lowercase `x` exists. Confirm where
# these names come from before running this cell.
proto_tensor = tf.make_tensor_proto(X)

In [None]:
# Convert the proto back into an array (presumably a NumPy ndarray, judging by
# the function name — confirm).
# NOTE(review): this rebinds `xdq`, replacing the dequantized tensor produced
# in Step-4; `tf` is not imported in the visible cells.
xdq = tf.make_ndarray(proto_tensor)
print(xdq)

In [None]:
# Create a torch tensor from the array currently held in `xdq` and display it.
Xdq = torch.from_numpy(xdq)
print(Xdq)
# Bare expression: shown as the cell result when run in a notebook
# (the recorded output is torch.float32).
Xdq.dtype

tensor([[0.0000e+00, 1.0000e+00, 2.0000e+00,  ..., 2.5400e+02, 2.5500e+02,
         2.5600e+02],
        [2.5700e+02, 2.5800e+02, 2.5900e+02,  ..., 5.1100e+02, 5.1200e+02,
         5.1300e+02],
        [5.1400e+02, 5.1500e+02, 5.1600e+02,  ..., 7.6800e+02, 7.6900e+02,
         7.7000e+02],
        ...,
        [1.1822e+04, 1.1823e+04, 1.1824e+04,  ..., 1.2076e+04, 1.2077e+04,
         1.2078e+04],
        [1.2079e+04, 1.2080e+04, 1.2081e+04,  ..., 1.2333e+04, 1.2334e+04,
         1.2335e+04],
        [1.2336e+04, 1.2337e+04, 1.2338e+04,  ..., 1.2590e+04, 1.2591e+04,
         1.2592e+04]])


torch.float32