# Quantized Convolution (Conv2d)

In [1]:
import torch as th
import warnings

warnings.filterwarnings("ignore")
th.backends.quantized.engine = "qnnpack"  # for ARM CPU
th.manual_seed(0)

<torch._C.Generator at 0x10bb82f90>

In [2]:
class Model(th.nn.Module):
    """A single ``th.nn.Conv2d`` wrapped between eager-mode quantization stubs.

    ``quant`` marks where float input enters the quantized region and ``dequant``
    marks where the result leaves it; ``th.ao.quantization.convert`` replaces them
    with real quantize / dequantize operations. All constructor arguments are
    forwarded unchanged to ``th.nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
        bias=False,
    ):
        super().__init__()
        # Later cells access these attributes by name (qm.quant, qm.conv, ...).
        self.quant = th.ao.quantization.QuantStub()
        self.conv = th.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.dequant = th.ao.quantization.DeQuantStub()

    def forward(self, x):
        """Quantize ``x``, convolve it, then dequantize the convolution output."""
        return self.dequant(self.conv(self.quant(x)))

In [3]:
# Layer / input geometry. PyTorch convolutions take NCHW input: (batch, channels, height, width).
input_channels = 2
output_channels = 2
input_width = 3
input_height = 3
kernel_size = 3
stride = 1
padding = 1
dilation = 1
groups = 1
bias = True
calibration_batch_size = 32  # number of random samples used to populate the observers
m = Model(
    input_channels,
    output_channels,
    kernel_size,
    stride,
    padding,
    dilation,
    groups,
    bias=bias,
)

# Signed 8-bit, per-tensor symmetric observer used for both activations and weights.
# `with_args` returns a factory, so each use still creates its own observer instance.
int8_symmetric_observer = th.ao.quantization.MovingAverageMinMaxObserver.with_args(
    quant_min=-128,
    quant_max=127,
    dtype=th.qint8,
    qscheme=th.per_tensor_symmetric,
    reduce_range=False,
)
m.qconfig = th.ao.quantization.QConfig(
    activation=int8_symmetric_observer,
    weight=int8_symmetric_observer,
)

# Prepare: attach observers (prepare_qat also swaps Conv2d for its QAT counterpart).
pm = th.ao.quantization.prepare_qat(m)

# "Train": a single forward pass over random data so the observers record value
# ranges. No backward step follows, so autograd tracking is unnecessary.
# Shape is (N, C, H, W) -- height before width.
with th.no_grad():
    pm(th.rand(calibration_batch_size, input_channels, input_height, input_width))

# Convert: replace the observed modules with their quantized implementations.
qm = th.ao.quantization.convert(pm.eval())

In [4]:
# Test: quantize a fresh random input, run the quantized model, and print the integer
# tensors and fixed-point constants an integer-only implementation would need.
x = th.rand(1, input_channels, input_height, input_width)  # NCHW: height before width
xq = qm.quant(x)
y = qm(x)
# Recover the integer output codes from the dequantized float output.
# Assumes an output zero point of 0, as produced by the per_tensor_symmetric qint8
# activation observer configured for this model.
yq = th.round(y / qm.conv.scale)
pdims = (0, 2, 3, 1)  # permutation dimensions: NCHW -> NHWC

# Fetch the quantized weight once instead of calling weight() for every print.
wq = qm.conv.weight()
# Scale of the integer accumulator: input scale * weight scale.
acc_scale = qm.quant.scale * wq.q_scale()

print(f"Float input: {x.flatten()}\n")
print(f"Quantized input: {xq.int_repr()}\n")
print(f"Permuted Quantized input: {xq.int_repr().permute(pdims).flatten()}\n")
print(f"Float output: {y}\n")
print(f"Quantized output: {yq}\n")
print(f"Permuted Quantized output: {yq.permute(pdims).flatten()}\n")
# Requantization multiplier (acc_scale / output_scale) expressed as an integer scaled
# by 2**31 (a Q31 value when the multiplier is below 1).
print(f"Multiplier: {th.round((acc_scale / qm.conv.scale) * (2**31))}\n")
print(f"Quantized weights: {wq.int_repr().int()}\n")
print(f"Permuted Quantized weights: {wq.int_repr().permute(pdims).flatten().int()}\n")
# Bias is quantized with the accumulator scale so it can be added to the int accumulator.
print(f"Quantized bias: {th.round(qm.conv.bias() / acc_scale).int()}\n")

Float input: tensor([0.9200, 0.1559, 0.0801, 0.2745, 0.5808, 0.9604, 0.2613, 0.6788, 0.3746,
        0.3916, 0.8677, 0.1125, 0.5531, 0.9702, 0.4313, 0.8882, 0.3460, 0.9025])

Quantized input: tensor([[[[117,  20,  10],
          [ 35,  74, 122],
          [ 33,  87,  48]],

         [[ 50, 111,  14],
          [ 71, 124,  55],
          [113,  44, 115]]]], dtype=torch.int8)

Permuted Quantized input: tensor([117,  50,  20, 111,  10,  14,  35,  71,  74, 124, 122,  55,  33, 113,
         87,  44,  48, 115], dtype=torch.int8)

Float output: tensor([[[[-0.1278, -0.2940, -0.1598],
          [-0.1214, -0.2493, -0.4027],
          [-0.4474, -0.7287, -0.3579]],

         [[-0.0447,  0.4091,  0.0000],
          [-0.0064, -0.1662, -0.1726],
          [ 0.3196, -0.1790,  0.0000]]]])

Quantized output: tensor([[[[ -20.,  -46.,  -25.],
          [ -19.,  -39.,  -63.],
          [ -70., -114.,  -56.]],

         [[  -7.,   64.,    0.],
          [  -1.,  -26.,  -27.],
          [  50.,  -28.,    0.]