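# Fully Connected

Quantizes a single bias-free fully connected layer (`nn.Linear`) to signed 8-bit integers with PyTorch eager-mode quantization on the QNNPACK engine. It then prints the quantized input, weights and output, together with the fixed-point (Q31) requantization multiplier.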

In [1]:
import torch as th
import warnings

# Silence library warnings so the printed results below stay readable.
# NOTE(review): this blanket filter also hides deprecation warnings from
# torch.ao.quantization — consider narrowing it with `category=`.
warnings.filterwarnings("ignore")

# Use the QNNPACK quantized backend (for ARM CPU); quantized ops produced by
# `convert` later in the notebook dispatch on this engine setting.
th.backends.quantized.engine = "qnnpack"

# Seed the global RNG before the model is built: the initial Linear weights and
# the random calibration/test inputs are all drawn from it. The trailing `;`
# keeps the returned Generator object out of the cell output.
SEED = 0
th.manual_seed(SEED);

<torch._C.Generator at 0x10c982f90>

In [2]:
class Model(th.nn.Module):
    """A single fully connected layer wrapped in eager-mode quantization stubs.

    Before ``convert``, ``QuantStub`` and ``DeQuantStub`` pass tensors through
    unchanged, so the module behaves like a plain ``nn.Linear``. After
    ``prepare``/``prepare_qat`` and ``convert``, ``quant`` turns the float input
    into a quantized tensor, ``fc`` becomes a quantized linear layer, and
    ``dequant`` turns its output back into a float tensor.

    Args:
        input_dims: Number of input features (``in_features`` of ``fc``).
        output_dims: Number of output features (``out_features`` of ``fc``).
        bias: Whether ``fc`` has a bias term. Defaults to ``False``, which keeps
            the original bias-free layer (and the same weight initialisation).
    """

    def __init__(self, input_dims, output_dims, bias=False):
        super().__init__()
        self.quant = th.ao.quantization.QuantStub()
        self.fc = th.nn.Linear(input_dims, output_dims, bias=bias)
        self.dequant = th.ao.quantization.DeQuantStub()

    def forward(self, x):
        """Quantize ``x``, apply ``fc``, and dequantize the result."""
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)

In [3]:
input_dims = 10
output_dims = 15
m = Model(input_dims, output_dims)

# One observer spec shared by activations and weights: signed 8-bit,
# per-tensor symmetric, full [-128, 127] range. Each QConfig field is a factory,
# so every observed tensor still gets its own observer instance.
int8_symmetric_observer = th.ao.quantization.MovingAverageMinMaxObserver.with_args(
    quant_min=-128,
    quant_max=127,
    dtype=th.qint8,
    qscheme=th.per_tensor_symmetric,
    reduce_range=False,
)
m.qconfig = th.ao.quantization.QConfig(
    activation=int8_symmetric_observer,
    weight=int8_symmetric_observer,
)

# Prepare: copy the model, swap `fc` for its QAT counterpart and attach
# observers. Because the qconfig uses plain observers rather than FakeQuantize
# modules, the prepared model only records value ranges; it does not simulate
# quantization rounding in the forward pass.
# NOTE: prepare_qat (rather than prepare) is kept on purpose — it re-creates
# the Linear layer, which advances the global RNG, and the inputs drawn below
# (and the outputs shown later) depend on that RNG state.
pm = th.ao.quantization.prepare_qat(m)

# Calibrate: one forward pass over random inputs lets the observers record
# activation and weight ranges. No optimizer step is taken, so the weights are
# not trained and no gradients are needed.
CALIBRATION_BATCH = 32
with th.no_grad():
    pm(th.rand(CALIBRATION_BATCH, input_dims))

# Convert: replace observed modules with quantized ones using the recorded ranges.
qm = th.ao.quantization.convert(pm.eval())
assert qm.fc.weight().dtype == th.qint8, "fc was not converted to a qint8 layer"

In [5]:
# Test: push one random input through the converted model and print the
# integer tensors behind the float result.
x = th.rand(1, input_dims)
xq = qm.quant(x)  # quantized input
y = qm(x)  # float output of quant -> fc -> dequant

# Recover the integer output codes from the dequantized output:
# q = round(y / scale) + zero_point (the zero point must not be dropped).
out_scale = qm.fc.scale
out_zero_point = qm.fc.zero_point
yq = th.round(y / out_scale) + out_zero_point
# Sanity check: the reconstruction matches the quantized layer's raw output.
assert th.equal(yq, qm.fc(xq).int_repr().to(yq.dtype))

# Requantization multiplier M = s_in * s_w / s_out as a Q31 fixed-point value.
w_q = qm.fc.weight()
multiplier = th.round((qm.quant.scale * w_q.q_scale() / out_scale) * (2**31))

print(f"Float input: {x}\n")
print(f"Quantized input: {xq.int_repr()}\n")
print(f"Float output: {y}\n")
print(f"Quantized output: {yq}\n")
print(f"Multiplier: {multiplier}\n")
print(f"Quantized weights: {w_q.int_repr().flatten()}\n")
# NOTE(review): quantized Linear appears to keep its bias in float, and this is
# None when the layer has no bias — the "Quantized bias" label may mislead.
print(f"Quantized bias: {qm.fc.bias()}\n")

Float input: tensor([[0.8140, 0.6299, 0.6581, 0.5464, 0.6864, 0.3782, 0.3011, 0.0326, 0.1233,
         0.7167]])

Quantized input: tensor([[104,  80,  84,  70,  88,  48,  38,   4,  16,  91]], dtype=torch.int8)

Float output: tensor([[-0.1863, -0.5974,  0.3148, -0.3276, -0.3918,  0.4111,  0.4754,  0.1221,
         -0.1991,  0.0000, -0.0193,  0.1285, -0.4175,  0.1413,  0.1734]])

Quantized output: tensor([[-29., -93.,  49., -51., -61.,  64.,  74.,  19., -31.,   0.,  -3.,  20.,
         -65.,  22.,  27.]])

Multiplier: tensor([6485986.])

Quantized weights: tensor([  -1,   69, -105,  -94,  -49,   34,   -3,  101,  -11,   34,  -39,  -25,
        -122,  -85,  -53,    5,   51,   77,  -87,  -56,   46,  106,  -26,   96,
         -21,   14,  116, -119,  -80,  -32,  -50,  110,  -83,  -59,  -89, -120,
         -75,  110,   57,   62,    7,  -66,   22, -119,  -92,  -66,   81,   75,
         -57,   -5,   82,  127,   51,   17,   86,  -75,   24,  -99,  -89,  -66,
          58,   51,  -76,   39,   70,  