# Fully Connected

In [1]:
import torch as th
import warnings

# Suppress every Python warning for the rest of the session (no category or
# message filter is given, so nothing is let through).
warnings.filterwarnings("ignore")
# Select the QNNPACK backend for quantized operators.
th.backends.quantized.engine = "qnnpack"  # for ARM CPU
# Seed torch's global RNG so the random tensors drawn later are repeatable.
# As the last expression of the cell, the returned Generator is displayed.
th.manual_seed(0)

<torch._C.Generator at 0x117082f90>

In [2]:
class Model(th.nn.Module):
    """A single fully connected layer wrapped in quantization stubs.

    The input passes through a ``QuantStub``, then the linear layer, then a
    ``DeQuantStub``.

    Args:
        input_dims: Number of input features of the linear layer.
        output_dims: Number of output features of the linear layer.
        bias: Whether the linear layer has an additive bias. Defaults to
            ``False``, which keeps the layer bias-free as before.
    """

    def __init__(self, input_dims, output_dims, bias=False):
        super().__init__()
        self.quant = th.ao.quantization.QuantStub()
        self.fc = th.nn.Linear(input_dims, output_dims, bias=bias)
        self.dequant = th.ao.quantization.DeQuantStub()

    def forward(self, x):
        """Run ``x`` through quant stub -> linear layer -> dequant stub."""
        return self.dequant(self.fc(self.quant(x)))

In [3]:
input_dims = 3
output_dims = 2
m = Model(input_dims, output_dims)

# Activations and weights share the same observer settings: per-tensor
# symmetric int8 over the full [-128, 127] range. `with_args` returns a
# factory, so each use still builds its own observer instance.
int8_symmetric_observer = th.ao.quantization.MovingAverageMinMaxObserver.with_args(
    quant_min=-128,
    quant_max=127,
    dtype=th.qint8,
    qscheme=th.per_tensor_symmetric,
    reduce_range=False,
)
m.qconfig = th.ao.quantization.QConfig(
    activation=int8_symmetric_observer,
    weight=int8_symmetric_observer,
)

# Prepare
pm = th.ao.quantization.prepare_qat(m)

# Calibrate: a single forward pass over random data (no backward pass or
# optimizer step) so the observers see some values. The feature dimension
# follows `input_dims` instead of repeating the literal 3.
pm(th.rand(32, input_dims))

# Convert
qm = th.ao.quantization.convert(pm.eval())

In [4]:
# Display the `dequant` submodule of the converted model.
qm.dequant

DeQuantize()

In [5]:
# Test
# One random sample whose width follows `input_dims` rather than a literal 3.
x = th.rand(1, input_dims)
xq = qm.quant(x)
y = qm(x)
# Map the dequantized float output back to the integer grid of the linear
# layer's output: q = round(y / scale) + zero_point. The zero point is 0 for
# the symmetric scheme configured in this notebook, so adding it only matters
# if the scheme is changed to an affine one.
yq = th.round(y / qm.fc.scale) + qm.fc.zero_point

# (input scale * weight scale / output scale) scaled by 2**31 and rounded;
# presumably the fixed-point requantization multiplier for an integer kernel.
multiplier = th.round(
    (qm.quant.scale * qm.fc.weight().q_scale() / qm.fc.scale) * (2**31)
)

print(f"Float input: {x}\n")
print(f"Quantized input: {xq.int_repr()}\n")
print(f"Float output: {y}\n")
print(f"Quantized output: {yq}\n")
print(f"Multiplier: {multiplier}\n")
print(f"Quantized weights: {qm.fc.weight().int_repr()}\n")
print(f"Quantized bias: {qm.fc.bias()}\n")

Float input: tensor([[0.4725, 0.5751, 0.2952]])

Quantized input: tensor([[60, 74, 38]], dtype=torch.int8)

Float output: tensor([[ 0.0374, -0.2808]])

Quantized output: tensor([[  8., -60.]])

Multiplier: tensor([13372818.])

Quantized weights: tensor([[  -1,   83, -127],
        [-114,  -60,   42]], dtype=torch.int8)

Quantized bias: tensor([0., 0.])

