In [2]:
import torch
import torch.nn as nn

In [3]:
class CustomModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.dequant(x)
        return x

net = CustomModel()
net.eval()

CustomModel(
  (quant): QuantStub()
  (conv1): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (dequant): DeQuantStub()
)

In [4]:
my_qconfig = torch.quantization.qconfig.QConfig(
        activation=torch.quantization.observer.HistogramObserver.with_args(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True),
        weight=torch.quantization.observer.PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False )
)

net.qconfig = my_qconfig
torch.backends.quantized.engine = "fbgemm"
torch.quantization.prepare(net, inplace=True)



CustomModel(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (conv1): Conv2d(
    1, 1, kernel_size=(1, 1), stride=(1, 1)
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (dequant): DeQuantStub()
)

In [5]:
calibrate_data = torch.randint(low=0, high=255, size=(1, 4, 16), dtype=torch.uint8).unsqueeze(0)
calibrate_data = calibrate_data / 255
_ = net(calibrate_data)

In [6]:
torch.quantization.convert(net, inplace=True)

CustomModel(
  (quant): Quantize(scale=tensor([0.0078]), zero_point=tensor([0]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=0.005468083545565605, zero_point=127)
  (dequant): DeQuantize()
)

In [7]:
net

CustomModel(
  (quant): Quantize(scale=tensor([0.0078]), zero_point=tensor([0]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=0.005468083545565605, zero_point=127)
  (dequant): DeQuantize()
)

In [8]:
activations = []
def custom_hook(module, input, output):
    info = {
        'module': module,
        'input': input,
        'output': output
    }
    activations.append(info)

for name, module in net.named_modules():
    if len(list(module.children())) == 0:
        module.register_forward_hook(custom_hook)

In [9]:
channel1 = torch.arange(0, 64).view(4, 16).to(torch.uint8).unsqueeze(0)
input_data = channel1.unsqueeze(0)
input_data = input_data / 255

In [10]:
_ = net(input_data)

In [12]:
activations[1]['input'][0].int_repr()
activations[1]['output'].int_repr()

tensor([[[[ 0,  0,  0,  1,  1,  2,  2,  2,  2,  3,  3,  4,  4,  5,  5,  5],
          [ 5,  6,  6,  7,  7,  7,  7,  8,  8,  9,  9, 10, 10, 10, 10, 11],
          [11, 12, 12, 12, 12, 13, 13, 14, 14, 15, 15, 15, 15, 16, 16, 17],
          [17, 17, 17, 18, 18, 19, 19, 20, 20, 20, 20, 21, 21, 22, 22, 22]]]],
       dtype=torch.uint8)

In [13]:
# https://discuss.pytorch.org/t/the-result-of-quantized-conv2d-is-different-from-the-result-i-calculate/157066/6
qx = activations[1]['input'][0].int_repr()
wx = net.conv1.weight().int_repr()

sinput = activations[1]['input'][0].q_scale()
sweight = activations[1]['module'].weight().q_per_channel_scales()[0]
soutput = activations[1]['module'].scale
zinput = activations[1]['input'][0].q_zero_point()
zweight = activations[1]['module'].weight().q_per_channel_zero_points()[0]
zoutput = activations[1]['module'].zero_point

bias = activations[1]['module'].bias()
qbias = torch.round(bias / (sinput * sweight))

qoutput = qx * wx + qbias
qoutput = torch.round(qoutput * sinput * sweight / soutput + zoutput)
qoutput = torch.clamp(qoutput, 0, 127)
print((activations[1]['output'].int_repr() == qoutput).sum())

tensor(64)


In [15]:
sinput

0.0078393230214715

In [14]:
zinput

0

In [16]:
def fixed_point_multiply(value, M0, n):
    # Step 1: Convert M0 to a 32-bit fixed-point integer
    M0_fixed = int(round((2**31) * M0))
    
    # Step 2: Perform the multiplication in fixed-point
    result_fixed = (value * M0_fixed) >> 31
    
    # Step 3: Apply the bit shift for 2^(-n)
    result_shifted = result_fixed >> n
    
    return result_shifted

# Example values
M0 = 0.75  # Example M0 in [0.5, 1)
n = 4      # Example n
value = 1000000  # Example value to be multiplied

# Perform the fixed-point multiplication
result = fixed_point_multiply(value, M0, n)

print(f"Fixed-point multiplication result: {result}")


Fixed-point multiplication result: 46875


In [17]:
M0_fixed = int(round((2**31) * M0))

In [18]:
M0_fixed

1610612736