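# Quantized Conv2d

Eager-mode, symmetric int8 quantization of a single `Conv2d` layer; the test cell prints the integer-domain input, output, requantization multiplier, weights and bias.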

In [1]:
import torch as th
import warnings

# Silence all warnings (e.g. quantization API notices) so the printed outputs
# further down stay readable.
warnings.filterwarnings(action="ignore")
# Quantized backend selection; QNNPACK is chosen here for ARM CPUs.
th.backends.quantized.engine = "qnnpack"
# Seed torch's global RNG so random weights and inputs are reproducible.
# This is the last expression of the cell, so its Generator return value is displayed.
th.manual_seed(0)

<torch._C.Generator at 0x122782f90>

In [2]:
class Model(th.nn.Module):
    """A single ``Conv2d`` wrapped in quant/dequant stubs for eager-mode quantization.

    ``QuantStub`` and ``DeQuantStub`` mark where tensors enter and leave the
    quantized region of the model; ``th.ao.quantization.convert`` later swaps
    them (and ``conv``) for their quantized counterparts.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size, stride, padding, dilation, groups: forwarded to
            ``th.nn.Conv2d`` unchanged.
        bias: whether the convolution has a bias term (defaults to ``False``
            here, unlike ``th.nn.Conv2d``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
        bias=False,
    ):
        super().__init__()
        conv_options = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Registration order (quant -> conv -> dequant) is kept as-is.
        self.quant = th.ao.quantization.QuantStub()
        self.conv = th.nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.dequant = th.ao.quantization.DeQuantStub()

    def forward(self, x):
        """Pass ``x`` through the quant stub, the convolution, then the dequant stub."""
        return self.dequant(self.conv(self.quant(x)))

In [3]:
# Layer / input configuration. Conv2d tensors are NCHW: (batch, channels, height, width).
input_width = 3
input_height = 3
input_channels = 1
output_channels = 1
kernel_size = 3
stride = 1
padding = 1
dilation = 1
groups = 1
bias = False
calibration_batch_size = 32  # number of random samples fed through the observers

m = Model(
    input_channels,
    output_channels,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    dilation=dilation,
    groups=groups,
    bias=bias,
)

# Symmetric signed 8-bit quantization over the full [-128, 127] range, shared by
# activations and weights. with_args returns a factory, so each use still gets
# its own observer instance.
int8_symmetric_observer = th.ao.quantization.MovingAverageMinMaxObserver.with_args(
    quant_min=-128,
    quant_max=127,
    dtype=th.qint8,
    qscheme=th.per_tensor_symmetric,
    reduce_range=False,
)
m.qconfig = th.ao.quantization.QConfig(
    activation=int8_symmetric_observer,
    weight=int8_symmetric_observer,
)

# Prepare: insert observers / QAT modules (prepare_qat requires train mode,
# which a freshly constructed module is in).
pm = th.ao.quantization.prepare_qat(m)

# Calibrate: a single forward pass on random data so the observers record value
# ranges. No optimizer step is taken, so the conv weights keep their random init.
# Input is NCHW, i.e. height before width.
pm(th.rand(calibration_batch_size, input_channels, input_height, input_width))

# Convert: replace the observed modules with quantized ones.
qm = th.ao.quantization.convert(pm.eval())

In [4]:
# Test: push one random sample through the quantized model and print the
# integer-domain quantities (input, output, requantization multiplier, weights,
# bias) that an integer-only implementation of this layer would need.
x = th.rand(1, input_channels, input_height, input_width)  # NCHW
xq = qm.quant(x)
y = qm(x)

w_q = qm.conv.weight()
in_scale = qm.quant.scale
w_scale = w_q.q_scale()
out_scale = qm.conv.scale
out_zero_point = qm.conv.zero_point

# Recover the integer output from the dequantized float output:
# q = round(y / scale) + zero_point (zero_point is 0 with the symmetric qconfig).
yq = th.round(y / out_scale) + out_zero_point

# conv.bias() may be None for a bias-free conv depending on the quantized
# engine; treat that as an all-zero bias so the bias print below cannot crash.
bias_f = qm.conv.bias()
if bias_f is None:
    bias_f = th.zeros(output_channels)

print(f"Float input: {x.flatten()}\n")
print(f"Quantized input: {xq.int_repr()}\n")
print(f"Flattened Quantized input: {xq.int_repr().flatten()}\n")
print(f"Float output: {y.flatten()}\n")
print(f"Quantized output: {yq.flatten()}\n")
# Requantization multiplier in_scale * w_scale / out_scale, scaled by 2**31
# (Q31 fixed point).
print(f"Multiplier: {th.round((in_scale * w_scale / out_scale) * (2**31))}\n")
print(f"Quantized weights: {w_q.int_repr().int()}\n")
# permute(1, 0, 2, 3) swaps the first two axes (out/in channels) of the
# (O, I, kH, kW) weight before flattening.
# NOTE(review): presumably matches the target integer kernel's layout — confirm.
print(
    f"Permuted Quantized weights: {w_q.int_repr().permute(1, 0, 2, 3).flatten().int()}\n"
)
# Bias expressed in units of in_scale * w_scale, the scale of the integer accumulator.
print(f"Quantized bias: {th.round(bias_f / (in_scale * w_scale)).int()}\n")

Float input: tensor([0.5832, 0.7130, 0.6979, 0.4371, 0.0901, 0.4229, 0.6737, 0.3176, 0.6898])

Quantized input: tensor([[[[75, 91, 89],
          [56, 12, 54],
          [86, 41, 88]]]], dtype=torch.int8)

Flattened Quantized input: tensor([75, 91, 89, 56, 12, 54, 86, 41, 88], dtype=torch.int8)

Float output: tensor([ 0.1028, -0.1625, -0.1526,  0.0299, -0.0862,  0.2256, -0.0066, -0.2455,
        -0.0929])

Quantized output: tensor([ 31., -49., -46.,   9., -26.,  68.,  -2., -74., -28.])

Multiplier: tensor([10893229.])

Quantized weights: tensor([[[[  -1,   83, -128],
          [-114,  -60,   42],
          [  -3,  123,  -14]]]], dtype=torch.int32)

Permuted Quantized weights: tensor([  -1,   83, -128, -114,  -60,   42,   -3,  123,  -14],
       dtype=torch.int32)

Quantized bias: tensor([0], dtype=torch.int32)



In [5]:
# Inspect the quantized conv weight (int8 values plus their scale / zero_point).
weight_q = qm.conv.weight()
weight_q

tensor([[[[-0.0022,  0.1786, -0.2754],
          [-0.2453, -0.1291,  0.0904],
          [-0.0065,  0.2647, -0.0301]]]], size=(1, 1, 3, 3), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0021517518907785416,
       zero_point=0)