# Fully Connected

In [1]:
import torch as th
import warnings

# Silence every Python warning for the rest of the session so the cell
# outputs below stay focused on the quantization results.
warnings.filterwarnings("ignore")

# Select the quantized kernel backend used by the converted model.
th.backends.quantized.engine = "qnnpack"  # for ARM CPU

# Fix the global RNG seed so weights and random inputs are reproducible.
# The trailing `;` stops the notebook from echoing the returned Generator
# repr, whose memory address changes on every run.
th.manual_seed(0);

<torch._C.Generator at 0x11929ef90>

In [2]:
class Model(th.nn.Module):
    """Single linear layer wrapped in quantization stubs.

    The input passes through a ``QuantStub``, a ``th.nn.Linear`` layer and a
    ``DeQuantStub``, in that order.

    Args:
        input_dims: Number of input features of the linear layer.
        output_dims: Number of output features of the linear layer.
        bias: Whether the linear layer has a bias term (default ``False``).
    """

    def __init__(self, input_dims, output_dims, bias=False):
        super().__init__()
        ao_quant = th.ao.quantization
        # Submodules are registered in the same order the forward pass uses them.
        self.quant = ao_quant.QuantStub()
        self.fc = th.nn.Linear(input_dims, output_dims, bias=bias)
        self.dequant = ao_quant.DeQuantStub()

    def forward(self, x):
        """Apply quant -> fc -> dequant to ``x`` and return the result."""
        return self.dequant(self.fc(self.quant(x)))

In [3]:
input_dims, output_dims = 10, 15
m = Model(input_dims, output_dims, bias=True)

# Activations and weights use the same observer settings:
# signed 8-bit, per-tensor symmetric, full [-128, 127] range.
int8_symmetric = dict(
    quant_min=-128,
    quant_max=127,
    dtype=th.qint8,
    qscheme=th.per_tensor_symmetric,
    reduce_range=False,
)
observer_cls = th.ao.quantization.MovingAverageMinMaxObserver
m.qconfig = th.ao.quantization.QConfig(
    activation=observer_cls.with_args(**int8_symmetric),
    weight=observer_cls.with_args(**int8_symmetric),
)

# Prepare: returns a copy of `m` instrumented according to `m.qconfig`.
pm = th.ao.quantization.prepare_qat(m)

# Train: a single forward pass on one random batch (no optimizer step here).
calibration_batch = th.rand(32, input_dims)
pm(calibration_batch)

# Convert: switch to eval mode, then build the quantized model.
pm.eval()
qm = th.ao.quantization.convert(pm)

In [4]:
# Test
x = th.rand(1, input_dims)
xq = qm.quant(x)
y = qm(x)

# Quantization parameters read back from the converted model.
input_scale = qm.quant.scale
weight_q = qm.fc.weight()
weight_scale = weight_q.q_scale()
output_scale = qm.fc.scale
# Product of input and weight scales; the bias is expressed in these units.
acc_scale = input_scale * weight_scale

# Dividing the float output by the output scale gives its integer level.
yq = th.round(y / output_scale)
# Requantization ratio scaled by 2**31.
multiplier = th.round((acc_scale / output_scale) * (2**31))
bias_q = th.round(qm.fc.bias() / acc_scale).int()

print(f"Float input: {x}\n")
print(f"Quantized input: {xq.int_repr()}\n")
print(f"Float output: {y}\n")
print(f"Quantized output: {yq}\n")
print(f"Multiplier: {multiplier}\n")
print(f"Quantized weights: {weight_q.int_repr().flatten()}\n")
print(f"Quantized bias: {bias_q}\n")

Float input: tensor([[0.7720, 0.2957, 0.9200, 0.1559, 0.0801, 0.2745, 0.5808, 0.9604, 0.2613,
         0.6788]])

Quantized input: tensor([[ 98,  38, 117,  20,  10,  35,  74, 122,  33,  87]], dtype=torch.int8)

Float output: tensor([[-0.2249, -0.5709, -0.1384,  0.1470,  0.3287, -0.3373,  0.5276, -0.5709,
         -0.3114,  0.1211, -0.2768, -0.2595, -0.7784, -0.3546,  0.3633]])

Quantized output: tensor([[-26., -66., -16.,  17.,  38., -39.,  61., -66., -36.,  14., -32., -30.,
         -90., -41.,  42.]])

Multiplier: tensor([4817036.])

Quantized weights: tensor([  -1,   69, -105,  -94,  -49,   34,   -3,  101,  -11,   34,  -39,  -25,
        -122,  -85,  -53,    5,   51,   77,  -87,  -56,   46,  106,  -26,   96,
         -21,   14,  116, -119,  -80,  -32,  -50,  110,  -83,  -59,  -89, -120,
         -75,  110,   57,   62,    7,  -66,   22, -119,  -92,  -66,   81,   75,
         -57,   -5,   82,  127,   51,   17,   86,  -75,   24,  -99,  -89,  -66,
          58,   51,  -76,   39,   70,  