# Convolution (Conv2d)

In [1]:
import torch as th
import warnings

# Global warning filter: silences *all* warnings for the rest of the session.
# NOTE(review): consider narrowing to the specific categories that are noisy.
warnings.filterwarnings("ignore")
th.backends.quantized.engine = "qnnpack"  # for ARM CPU
# Fixed seed so the random calibration/test tensors below are reproducible.
# The trailing ';' keeps the returned Generator object out of the cell output.
th.manual_seed(0);

<torch._C.Generator at 0x104ffaf90>

In [2]:
class Model(th.nn.Module):
    """A single ``th.nn.Conv2d`` wrapped between quantization stubs.

    ``quant`` / ``dequant`` mark where eager-mode quantization moves from
    float to quantized tensors and back; after ``convert`` (cell 3) ``quant``
    yields a quantized tensor, as used in the test cell.

    Args mirror ``th.nn.Conv2d``: in_channels, out_channels, kernel_size,
    stride, padding, dilation, groups and bias (note: bias defaults to False
    here, unlike Conv2d).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
        bias=False,
    ):
        super().__init__()
        # Registration order (quant, conv, dequant) is kept stable on purpose:
        # it determines the order of children / state_dict entries.
        self.quant = th.ao.quantization.QuantStub()
        self.conv = th.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.dequant = th.ao.quantization.DeQuantStub()

    def forward(self, x):
        """Quantize ``x``, convolve it, and dequantize the result."""
        return self.dequant(self.conv(self.quant(x)))

In [3]:
# Layer / input configuration for the layer under test.
input_channels = 2
output_channels = 2
input_width = 5
input_height = 5
kernel_size = 3
stride = 2
padding = 1
dilation = 1
groups = 1
bias = True
m = Model(
    input_channels,
    output_channels,
    kernel_size,
    stride,
    padding,
    dilation,
    groups,
    bias=bias,
)

# One observer spec shared by activations and weights: symmetric per-tensor
# int8 over the full [-128, 127] range (reduce_range=False).
int8_symmetric_observer = th.ao.quantization.MovingAverageMinMaxObserver.with_args(
    quant_min=-128,
    quant_max=127,
    dtype=th.qint8,
    qscheme=th.per_tensor_symmetric,
    reduce_range=False,
)
m.qconfig = th.ao.quantization.QConfig(
    activation=int8_symmetric_observer,
    weight=int8_symmetric_observer,
)

# Prepare (returns a prepared copy; `m` itself is left as built).
pm = th.ao.quantization.prepare_qat(m)

# Calibrate: a single forward pass over random data so the observers record
# value ranges; no optimizer step is taken. Conv2d input layout is
# (N, C, H, W), so height precedes width.
# NOTE(review): the QConfig supplies plain observers rather than FakeQuantize
# modules, so this pass presumably only records ranges and does not simulate
# quantization during the forward — confirm that is intended.
pm(th.rand(32, input_channels, input_height, input_width))

# Convert the calibrated model (in eval mode) to its quantized form.
qm = th.ao.quantization.convert(pm.eval())

In [4]:
# Test: run one random sample through the converted model and dump the integer
# tensors (input, output, weights, bias) and the requantization multiplier,
# both in PyTorch's native order and permuted with `pdims`.
x = th.rand(1, input_channels, input_height, input_width)  # (N, C, H, W)
xq = qm.quant(x)
y = qm(x)
# Recover the integer output from the dequantized floats. This ignores the
# output zero point. NOTE(review): exact only if that zero point is 0, as
# expected for the per_tensor_symmetric activation observer — confirm.
yq = th.round(y / qm.conv.scale)
pdims = (0, 2, 3, 1)  # permutation dimensions: NCHW -> NHWC (weights: OIHW -> OHWI)

# Hoisted: the quantized weight, and the scale shared by the integer
# accumulator and the quantized bias (input scale * weight scale).
w_q = qm.conv.weight()
acc_scale = qm.quant.scale * w_q.q_scale()

print(f"Float input: {x.flatten()}\n")
print(f"Quantized input: {xq.int_repr()}\n")
print(f"Permuted Quantized input: {xq.int_repr().permute(pdims).flatten()}\n")
print(f"Float output: {y}\n")
print(f"Quantized output: {yq}\n")
print(f"Permuted Quantized output: {yq.permute(pdims).flatten()}\n")
# Requantization multiplier (input_scale * weight_scale / output_scale) as a
# fixed-point value with 31 fractional bits.
print(f"Multiplier: {th.round((acc_scale / qm.conv.scale) * (2**31))}\n")
print(f"Quantized weights: {w_q.int_repr().int()}\n")
print(f"Permuted Quantized weights: {w_q.int_repr().permute(pdims).flatten().int()}\n")
# Bias expressed in units of the accumulator scale, as int32.
print(f"Quantized bias: {th.round(qm.conv.bias() / acc_scale).int()}\n")

Float input: tensor([0.9795, 0.3072, 0.0535, 0.2620, 0.0360, 0.9120, 0.2755, 0.1482, 0.4460,
        0.1631, 0.2388, 0.7549, 0.1246, 0.7358, 0.8646, 0.4989, 0.7220, 0.2199,
        0.8899, 0.9857, 0.6188, 0.9556, 0.0034, 0.4451, 0.5993, 0.0654, 0.7433,
        0.3094, 0.3192, 0.1033, 0.7044, 0.1362, 0.3939, 0.7221, 0.8767, 0.6131,
        0.3135, 0.2825, 0.7806, 0.5928, 0.9209, 0.6883, 0.9028, 0.0971, 0.9020,
        0.2703, 0.3662, 0.5030, 0.4062, 0.5989])

Quantized input: tensor([[[[125,  39,   7,  33,   5],
          [116,  35,  19,  57,  21],
          [ 30,  96,  16,  94, 110],
          [ 64,  92,  28, 113, 126],
          [ 79, 122,   0,  57,  76]],

         [[  8,  95,  39,  41,  13],
          [ 90,  17,  50,  92, 112],
          [ 78,  40,  36, 100,  76],
          [117,  88, 115,  12, 115],
          [ 34,  47,  64,  52,  76]]]], dtype=torch.int8)

Permuted Quantized input: tensor([125,   8,  39,  95,   7,  39,  33,  41,   5,  13, 116,  90,  35,  17,
         19,  50,  57,