# An overview of QuantTensor and QuantConv2d

In this initial tutorial, we take a first look at `QuantTensor`, a basic data structure in Brevitas, and at `QuantConv2d`, a typical quantized layer. `QuantConv2d` is an instance of a `QuantWeightBiasInputOutputLayer` (typically imported as `QuantWBIOL`), meaning that it supports quantization of its weight, bias, input and output. Other instances of `QuantWBIOL` are `QuantLinear`, `QuantConv1d`, `QuantConvTranspose1d` and `QuantConvTranspose2d`, and they all follow the same principles.

If we take a look at the `__init__` method of `QuantConv2d`, we notice a few things:

In [1]:
import inspect
from brevitas.nn import QuantConv2d
from IPython.display import Markdown, display

def pretty_print_source(source):
    display(Markdown('```python\n' + source + '\n```'))
    
source = inspect.getsource(QuantConv2d.__init__)  
pretty_print_source(source)

```python
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int]] = 0,
            dilation: Union[int, Tuple[int, int]] = 1,
            groups: int = 1,
            bias: bool = True,
            padding_type: str = 'standard',
            weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
            bias_quant: Optional[BiasQuantType] = None,
            input_quant: Optional[ActQuantType] = None,
            output_quant: Optional[ActQuantType] = None,
            return_quant_tensor: bool = False,
            **kwargs) -> None:
        Conv2d.__init__(
            self,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        QuantWBIOL.__init__(
            self,
            weight_quant=weight_quant,
            bias_quant=bias_quant,
            input_quant=input_quant,
            output_quant=output_quant,
            return_quant_tensor=return_quant_tensor,
            **kwargs)
        assert self.padding_mode == 'zeros'
        assert not (padding_type == 'same' and padding != 0)
        self.padding_type = padding_type

```

`QuantConv2d` is an instance of both `Conv2d` and `QuantWBIOL`. Its initialization method exposes the usual arguments of a `Conv2d`, as well as: an extra flag to support *same padding*; *four* different arguments to set a quantizer for - respectively - *weight*, *bias*, *input*, and *output*; a `return_quant_tensor` boolean flag; the `**kwargs` placeholder to intercept additional arbitrary keyword arguments.  
In this tutorial we will focus on how to set the four quantizer arguments and the `return_quant_tensor` flag; arbitrary kwargs will be explained in a separate tutorial dedicated to defining and overriding quantizers.

By default `weight_quant=Int8WeightPerTensorFloat`, while `bias_quant`, `input_quant` and `output_quant` are set to `None`. That means that by default weights are quantized to *8-bit signed integer with a per-tensor floating-point scale factor* (a very common type of quantization adopted by e.g. the ONNX standard opset), while quantization of bias, input, and output are disabled. We can easily verify all of this at runtime on an example:

In [2]:
default_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=False)

In [3]:
print(f'Is weight quant enabled: {default_quant_conv.is_weight_quant_enabled}')
print(f'Is bias quant enabled: {default_quant_conv.is_bias_quant_enabled}')
print(f'Is input quant enabled: {default_quant_conv.is_input_quant_enabled}')
print(f'Is output quant enabled: {default_quant_conv.is_output_quant_enabled}')

Is weight quant enabled: True
Is bias quant enabled: False
Is input quant enabled: False
Is output quant enabled: False


If we now try to pass in a random floating-point tensor as input, as expected we get the output of the convolution:

In [4]:
import torch

out = default_quant_conv(torch.randn(1, 2, 5, 5))
out

tensor([[[[-0.1007, -0.2631,  0.9119],
          [ 0.3060,  0.3174,  0.6748],
          [-0.2179, -0.4119, -0.4807]],

         [[ 0.4296,  0.0781, -0.2309],
          [ 0.3522, -0.6440, -0.5089],
          [ 0.8133,  0.3387, -0.0395]],

         [[-0.7194,  0.9901,  0.5440],
          [-1.1865,  1.5809, -1.0971],
          [-0.7248, -0.1470, -0.0498]]]], grad_fn=<ThnnConv2DBackward>)

In this case we are computing the convolution between an unquantized input tensor and quantized weights, so the output in general is unquantized.

A `QuantConv2d` with quantization disabled everywhere behaves like a standard `Conv2d`. Again, we can easily verify this:

In [5]:
from torch.nn import Conv2d

torch.manual_seed(0)  # set a seed to make sure the random weight init is reproducible
disabled_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=False, weight_quant=None)
torch.manual_seed(0)  # reproduce the same random weight init as above
float_conv = Conv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=False)
inp = torch.randn(1, 2, 5, 5)
torch.isclose(disabled_quant_conv(inp), float_conv(inp)).all().item()

True

As we have just seen, Brevitas allows users as much freedom as possible to experiment with quantization, meaning that computation between quantized and unquantized values is considered legal. This allows users to mix Brevitas layers with PyTorch layers with few restrictions.  
To make this possible, quantized values are typically represented in *dequantized format*, meaning that - in the case of affine quantization implemented in Brevitas - zero-point and scale factor are applied to their integer values according to the formula **quant_value = (integer_value - zero_point) * scale**.
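
As a minimal sketch of the formula above, the following plain-Python snippet quantizes a few floats to signed 8-bit integers and maps them back to dequantized format. The helper names (`quantize_to_int`, `dequantize`) and the chosen scale are illustrative assumptions and are not part of the Brevitas API.

```python
def quantize_to_int(x, scale, zero_point=0, bit_width=8, signed=True):
    """Map a float to the nearest representable integer (hypothetical helper)."""
    if signed:
        min_int, max_int = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    else:
        min_int, max_int = 0, 2 ** bit_width - 1
    int_value = round(x / scale + zero_point)
    # Saturate to the range allowed by bit_width and signed.
    return max(min_int, min(max_int, int_value))


def dequantize(int_value, scale, zero_point=0):
    """Apply quant_value = (integer_value - zero_point) * scale."""
    return (int_value - zero_point) * scale


demo_scale = 0.5  # illustrative value, exact in binary floating point
demo_floats = [0.26, -0.74, 100.0]
demo_ints = [quantize_to_int(x, demo_scale) for x in demo_floats]
demo_dequantized = [dequantize(i, demo_scale) for i in demo_ints]

print(demo_ints)         # [1, -1, 127]
print(demo_dequantized)  # [0.5, -0.5, 63.5]
```

The last input saturates at 127, the largest signed 8-bit integer, so its dequantized value is clipped to 63.5.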

## QuantTensor

We can directly observe the quantized weights by calling the weight quantizer on the layer's weights: `default_quant_conv.weight_quant(default_quant_conv.weight)`, which for brevity is already implemented as `default_quant_conv.quant_weight()`:

In [6]:
default_quant_conv.quant_weight()

QuantTensor(value=tensor([[[[-0.1684, -0.0722,  0.1554],
          [-0.1499, -0.1554, -0.1332],
          [-0.1665,  0.1388,  0.2220]],

         [[-0.1277,  0.1813, -0.2294],
          [ 0.0740, -0.0259, -0.1628],
          [ 0.1425, -0.1906, -0.2109]]],


        [[[ 0.0463,  0.2350, -0.1480],
          [-0.2350,  0.1221, -0.0074],
          [ 0.1369,  0.0814, -0.0185]],

         [[-0.0648,  0.1684, -0.1517],
          [ 0.1628,  0.1517,  0.1998],
          [-0.0130,  0.2257, -0.1221]]],


        [[[-0.1758,  0.1166, -0.0592],
          [ 0.1425, -0.0796, -0.1499],
          [-0.1832, -0.0278, -0.2294]],

         [[ 0.2054, -0.0296,  0.1702],
          [ 0.0185, -0.2294,  0.0426],
          [ 0.0352,  0.0037, -0.1258]]]], grad_fn=<MulBackward0>), scale=tensor(0.0019, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))

Notice how the quantized weights are wrapped in a data structure implemented by Brevitas called `QuantTensor`. A `QuantTensor` is a way to represent an affine quantized tensor with all its metadata, meaning: the `value` of the quantized tensor in *dequantized* format, `scale`, `zero_point`, `bit_width`, whether the quantized value is `signed` or not, and whether the tensor was generated in `training` mode.

As expected, the quantized value (in dequantized format) can be computed from its integer representation, together with zero-point and scale:

In [7]:
int_weight = default_quant_conv.int_weight()
zero_point = default_quant_conv.quant_weight_zero_point()
scale = default_quant_conv.quant_weight_scale()
quant_weight_manually = (int_weight - zero_point) * scale
default_quant_conv.quant_weight().value.isclose(quant_weight_manually).all().item()

True

A *valid* `QuantTensor` correctly populates all its fields with values `!= None` and respects the **affine quantization invariant**, i.e. `value / scale + zero_point` is (accounting for rounding errors) an *integer* that can be represented within the interval defined by the `bit_width` and `signed` fields of the `QuantTensor`. A *non-valid* one doesn't.
We can observe that the quantized weights are indeed marked as valid:

In [8]:
default_quant_conv.quant_weight().is_valid

True
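
The affine quantization invariant can be sketched for a single scalar in plain Python. The function `is_affine_valid` and its tolerance are assumptions made for illustration; this is not how Brevitas implements `is_valid`.

```python
def is_affine_valid(value, scale, zero_point, bit_width, signed, tol=1e-6):
    """Check the affine quantization invariant for one scalar (hypothetical helper)."""
    # A QuantTensor with missing metadata cannot be valid.
    if scale is None or zero_point is None or bit_width is None or signed is None:
        return False
    int_repr = value / scale + zero_point
    # The integer representation must be an integer, up to rounding errors.
    if abs(int_repr - round(int_repr)) > tol:
        return False
    if signed:
        min_int, max_int = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    else:
        min_int, max_int = 0, 2 ** bit_width - 1
    # It must also fit in the range defined by bit_width and signed.
    return min_int <= round(int_repr) <= max_int


print(is_affine_valid(0.75, 0.25, 0, 8, True))          # True: 0.75 / 0.25 = 3
print(is_affine_valid(0.80, 0.25, 0, 8, True))          # False: 3.2 is not an integer
print(is_affine_valid(64.0, 0.25, 0, 8, True))          # False: 256 does not fit in 8 bits
print(is_affine_valid(0.75, None, None, None, None))    # False: metadata missing
```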

Calling `is_valid` is relatively expensive, so it should be used sparingly, but there are a few cases where a non-valid `QuantTensor` might be generated that are important to be aware of. Say we instantiate the layer again, this time with `return_quant_tensor=True`:

In [9]:
return_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=False, return_quant_tensor=True)

We then again pass as input a random floating-point tensor. Because `input_quant=None` and `output_quant=None` (i.e. both input and output quantization are disabled), again as before we are performing a convolution between a quantized and an unquantized tensor, which in general returns an unquantized tensor:

In [10]:
out_tensor = return_quant_conv(torch.randn(1, 2, 5, 5))
out_tensor

QuantTensor(value=tensor([[[[-0.6238, -0.1567,  0.5639],
          [-0.3426,  0.0662,  0.6296],
          [-0.6507,  0.4468, -0.4465]],

         [[ 1.0313,  0.7856, -0.2931],
          [-0.6213,  0.5228,  0.7288],
          [ 0.1397,  0.0216,  0.7518]],

         [[ 0.0763, -0.3561, -0.0491],
          [-0.5127,  0.2945,  0.5501],
          [-0.2460, -0.0052,  0.0395]]]], grad_fn=<ThnnConv2DBackward>), scale=None, zero_point=None, bit_width=None, signed_t=None, training_t=tensor(True))

Because we set `return_quant_tensor=True`, we get a `QuantTensor` as output object. However, we observe that `scale`, `zero_point` and `bit_width` of the output `QuantTensor` are set to `None`. This is expected since the output tensor is unquantized. In this case the `QuantTensor` is really just acting as a wrapper around a `torch.Tensor`, and as such is marked as non-valid.

In [11]:
out_tensor.is_valid

False

`QuantTensor` implements `__torch_function__` to handle being called from torch functional operators (e.g. ops under `torch.nn.functional`). Passing a `QuantTensor` to supported ops that are invariant to quantization, e.g. max-pooling, preserves the validity of the `QuantTensor`. Example:

In [12]:
import torch
from brevitas.nn import QuantIdentity

quant_identity = QuantIdentity(return_quant_tensor=True)
quant_tensor = quant_identity(torch.randn(1, 3, 4, 4))
torch.nn.functional.max_pool2d(quant_tensor, kernel_size=2, stride=2)



QuantTensor(value=tensor([[[[ 1.4362,  0.5809],
          [ 1.0327,  1.9041]],

         [[ 0.5809, -0.1452],
          [ 1.3878,  2.0494]],

         [[ 0.7100,  0.5809],
          [ 1.1780,  0.9359]]]]), scale=tensor(0.0161), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))

For ops that are not invariant to quantization, a `QuantTensor` decays into a floating-point `torch.Tensor`. Example:

In [13]:
torch.tanh(quant_tensor)

tensor([[[[-0.9023, -0.9424,  0.0805, -0.7936],
          [ 0.8929, -0.1125,  0.5233, -0.1600],
          [-0.0645,  0.7750,  0.7181, -0.4366],
          [-0.7021,  0.1442,  0.9566,  0.3690]],

         [[-0.0323, -0.6678, -0.4623, -0.3408],
          [ 0.5233, -0.8216, -0.1442, -0.1913],
          [-0.0805,  0.8827,  0.5350, -0.4495],
          [-0.3550, -0.8676,  0.9347,  0.9674]],

         [[-0.2221,  0.6107,  0.5233,  0.2676],
          [-0.9684,  0.4873, -0.4873, -0.7021],
          [-0.0161,  0.8268,  0.3829, -0.1913],
          [-0.8216, -0.3120,  0.7333, -0.8163]]]])
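
A small plain-Python sketch of why max-pooling is invariant to quantization while `tanh` is not: the maximum of values lying on the quantization grid is itself one of those values, whereas `tanh` generally maps them off the grid. The helper `on_grid` and the numbers used are illustrative assumptions.

```python
import math

grid_scale = 0.25  # illustrative scale, exact in binary floating point
grid_ints = [-3, 1, 4, 2]
grid_values = [i * grid_scale for i in grid_ints]  # dequantized format, zero_point = 0


def on_grid(x, scale, tol=1e-9):
    """True if x is an integer multiple of scale (hypothetical helper)."""
    ratio = x / scale
    return abs(ratio - round(ratio)) < tol


pooled = max(grid_values)                        # what one max-pooling window returns
squashed = [math.tanh(v) for v in grid_values]   # element-wise non-linearity

print(on_grid(pooled, grid_scale))                      # True
print(all(on_grid(s, grid_scale) for s in squashed))    # False
```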

## Input Quantization

We can obtain a valid output `QuantTensor` by making sure that both input and weight of `QuantConv2d` are quantized. To do so, we can set a quantizer for `input_quant`. In this example we pick a *signed 8-bit* quantizer with *per-tensor floating-point scale factor*:

In [14]:
from brevitas.quant.scaled_int import Int8ActPerTensorFloat

input_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=False, 
    input_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
out_tensor = input_quant_conv(torch.randn(1, 2, 5, 5))
out_tensor

QuantTensor(value=tensor([[[[-0.1760, -0.3239,  0.8647],
          [ 0.2300, -0.9457, -0.5969],
          [-0.1486,  0.2389, -0.1381]],

         [[ 0.4634, -0.4049, -0.3049],
          [-0.0643,  1.0154,  0.6058],
          [-0.0367, -0.9156,  0.1461]],

         [[ 0.7388,  0.6103,  0.8035],
          [ 0.3011,  1.0519, -0.2473],
          [-0.3113, -1.8302,  0.0223]]]], grad_fn=<ThnnConv2DBackward>), scale=tensor([[[[3.8195e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))

In [15]:
out_tensor.is_valid

True

What happens internally is that the input tensor passed to `input_quant_conv` is quantized before being passed to the convolution operator. That means we are now computing a convolution between two quantized tensors, which implies that the output of the operation is also quantized. As expected, `out_tensor` is then marked as valid.

Another important thing to notice is how the `bit_width` field of `out_tensor` is relatively high at *21 bits*. In Brevitas, the assumption is always that the output bit-width of an operator reflects the worst-case size of the *accumulator* required by that operation. In other terms, given the *size* of the input and weight tensors and their *bit-widths*, 21 is the bit-width that would be required to represent the largest possible output value that could be generated. This makes sure that the affine quantization invariant is always respected.
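
One conservative way to reproduce the 21-bit figure is to bound every product by the largest unsigned 8-bit magnitude and multiply by the number of products summed per output element (2 input channels times a 3x3 kernel). This is an assumption about how such a bound can be derived, written as a hypothetical helper; it is not necessarily the exact formula Brevitas applies internally.

```python
import math


def worst_case_acc_bit_width(input_bit_width, weight_bit_width, products_per_output):
    """Conservative accumulator size in bits (hypothetical helper, assumed bound)."""
    max_input = 2 ** input_bit_width - 1     # largest unsigned magnitude of an input
    max_weight = 2 ** weight_bit_width - 1   # largest unsigned magnitude of a weight
    max_output = max_input * max_weight * products_per_output
    return math.ceil(math.log2(max_output))


products_per_output = 2 * 3 * 3  # in_channels * kernel_height * kernel_width
acc_bits = worst_case_acc_bit_width(8, 8, products_per_output)
print(acc_bits)  # 21
```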

We could have obtained a similar result by directly passing as input a QuantTensor. In this example we are directly defining a QuantTensor ourselves, but it could also be the output of a previous layer.

In [16]:
from brevitas.quant_tensor import QuantTensor

scale = 0.0001
bit_width = 8
zero_point = 0.
# torch.randint excludes `high`, so use 2 ** (bit_width - 1) to cover the full signed range
int_value = torch.randint(low=- 2 ** (bit_width - 1), high=2 ** (bit_width - 1), size=(1, 2, 5, 5))
quant_value = (int_value - zero_point) * scale
quant_tensor_input = QuantTensor(
    quant_value, 
    scale=torch.tensor(scale), 
    zero_point=torch.tensor(zero_point), 
    bit_width=torch.tensor(float(bit_width)),
    signed=True,
    training=True)
quant_tensor_input

QuantTensor(value=tensor([[[[-1.1500e-02, -5.8000e-03, -9.3000e-03,  1.0000e-02,  3.5000e-03],
          [-6.8000e-03,  1.1500e-02, -1.0600e-02, -1.5000e-03, -1.9000e-03],
          [ 2.9000e-03,  9.5000e-03,  7.2000e-03, -3.7000e-03,  7.7000e-03],
          [-2.4000e-03, -8.9000e-03, -1.2000e-02, -8.1000e-03,  7.2000e-03],
          [-1.1300e-02, -9.7000e-03, -1.0000e-03,  1.0100e-02,  3.8000e-03]],

         [[-1.1900e-02,  6.9000e-03,  8.3000e-03,  1.0000e-04, -6.9000e-03],
          [ 3.9000e-03, -5.4000e-03,  1.1300e-02, -6.0000e-03,  9.7000e-03],
          [ 0.0000e+00,  1.0900e-02, -1.0900e-02,  1.1400e-02, -6.4000e-03],
          [ 9.2000e-03,  7.1000e-03, -6.0000e-04,  9.2000e-03, -8.5000e-03],
          [ 5.0000e-03,  6.5000e-03, -8.3000e-03, -1.2000e-03,  7.4000e-03]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))

In [17]:
quant_tensor_input.is_valid

True

**Note**: we are explicitly forcing `value`, `scale`, `zero_point` and `bit_width` to be floating-point `torch.Tensor`s, as this is expected by Brevitas but is currently not enforced automatically at initialization time.

If we now pass in `quant_tensor_input` to `return_quant_conv`, we will see that indeed the output is a valid 21-bit `QuantTensor`:

In [18]:
out_tensor = return_quant_conv(quant_tensor_input)
out_tensor

QuantTensor(value=tensor([[[[ 0.0002, -0.0017, -0.0038],
          [-0.0020,  0.0038,  0.0001],
          [ 0.0022,  0.0034, -0.0024]],

         [[-0.0010, -0.0078,  0.0010],
          [ 0.0002,  0.0014,  0.0059],
          [ 0.0050, -0.0020, -0.0069]],

         [[ 0.0006, -0.0071, -0.0012],
          [-0.0011,  0.0057,  0.0055],
          [-0.0003,  0.0040, -0.0019]]]], grad_fn=<ThnnConv2DBackward>), scale=tensor([[[[1.7899e-07]]]], grad_fn=<MulBackward0>), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))

In [19]:
out_tensor.is_valid

True

We can also pass in an input `QuantTensor` to a layer that has `input_quant` enabled. In that case, the input gets re-quantized:

In [20]:
input_quant_conv(quant_tensor_input)

QuantTensor(value=tensor([[[[ 0.0045, -0.0035, -0.0022],
          [ 0.0027,  0.0001,  0.0051],
          [ 0.0030, -0.0050, -0.0006]],

         [[-0.0079,  0.0035,  0.0032],
          [-0.0026, -0.0034, -0.0005],
          [ 0.0022,  0.0095, -0.0014]],

         [[-0.0030, -0.0045,  0.0031],
          [ 0.0038,  0.0023,  0.0054],
          [ 0.0094,  0.0156, -0.0074]]]], grad_fn=<ThnnConv2DBackward>), scale=tensor([[[[1.7393e-07]]]], grad_fn=<MulBackward0>), zero_point=tensor(0.), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))

## Output Quantization

Let's now look at what would have happened if we had instead enabled output quantization:

In [21]:
from brevitas.quant.scaled_int import Int8ActPerTensorFloat

output_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=False, 
    output_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
out_tensor = output_quant_conv(torch.randn(1, 2, 5, 5))
out_tensor

QuantTensor(value=tensor([[[[ 0.1127,  0.7247,  0.4509],
          [ 0.0805, -0.5153, -0.7569],
          [-0.0483, -0.9662, -0.1610]],

         [[ 0.3060,  0.7730, -0.3704],
          [ 0.1771,  0.5153, -0.0805],
          [ 0.6925, -0.4831,  1.1272]],

         [[ 2.0451, -0.8535,  0.1932],
          [ 0.1610,  1.2239, -0.4670],
          [-0.7086,  0.6441, -0.9984]]]], grad_fn=<MulBackward0>), scale=tensor(0.0161, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))

In [22]:
out_tensor.is_valid

True

We can see again that the output is a valid `QuantTensor`. However, what happened internally is quite different from before.  
Previously, we computed the convolution between two quantized tensors, and got a quantized tensor as output.  
In this case instead, we compute the convolution between a quantized and an unquantized tensor, we take its unquantized output and we quantize it.  
The difference is obvious once we look at the output `bit_width`. In the previous case, we had that the `bit_width` reflected the size of the output accumulator. In this case instead, we have `bit_width=tensor(8.)`, which is what we expected since `output_quant` had been set to an *Int8* quantizer.

## Bias Quantization

There is an important scenario where the various options we just saw make a practical difference, and it's quantization of *bias*. In many contexts, such as in the ONNX standard opset and in FINN, bias is assumed to be quantized with a scale factor equal to the product of the *input scale* and the *weight scale*, which means that we need a valid quantized input somehow. A predefined bias quantizer that reflects that assumption is `brevitas.quant.scaled_int.Int8Bias`. If we simply tried to set it on a `QuantConv2d` without any sort of input quantization, we would get an error:

In [23]:
from brevitas.quant.scaled_int import Int8Bias

bias_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
    bias_quant=Int8Bias, return_quant_tensor=True)
bias_quant_conv(torch.randn(1, 2, 5, 5))

RuntimeError: Input scale required
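
To see why this bias quantizer needs an input scale, consider the following plain-Python sketch. When the bias scale equals the input scale times the weight scale, the integer bias can be added directly to the integer accumulator and the result rescaled once. All names and numbers here are made up for illustration.

```python
in_scale = 0.5     # illustrative values, exact in binary floating point
w_scale = 0.25
b_scale = in_scale * w_scale  # bias scale = input scale * weight scale

int_inputs = [3, -2, 5]
int_weights = [4, 1, -6]
float_bias = 0.375

int_bias = round(float_bias / b_scale)  # 0.375 / 0.125 = 3

# Integer-only path: accumulate products, add the integer bias, rescale once.
int_acc = sum(i * w for i, w in zip(int_inputs, int_weights)) + int_bias
out_from_ints = int_acc * b_scale

# Reference path computed on values in dequantized format.
out_from_floats = sum(
    (i * in_scale) * (w * w_scale) for i, w in zip(int_inputs, int_weights)
) + int_bias * b_scale

print(int_acc, out_from_ints, out_from_floats)  # -17 -2.125 -2.125
```

Without a quantized input there is no input scale, so `b_scale` cannot be computed.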

We can solve the issue by passing in a valid `QuantTensor`, e.g. the `quant_tensor_input`  we defined above:

In [24]:
bias_quant_conv(quant_tensor_input)

QuantTensor(value=tensor([[[[-0.0040, -0.0009,  0.0006],
          [ 0.0053, -0.0011,  0.0011],
          [ 0.0043,  0.0068,  0.0016]],

         [[ 0.0010,  0.0020,  0.0037],
          [-0.0032, -0.0055,  0.0034],
          [ 0.0078, -0.0071,  0.0038]],

         [[ 0.0012, -0.0061,  0.0010],
          [ 0.0008,  0.0001, -0.0060],
          [ 0.0030, -0.0052,  0.0061]]]], grad_fn=<ThnnConv2DBackward>), scale=tensor([[[[1.8108e-07]]]], grad_fn=<MulBackward0>), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))

Or by enabling input quantization and then passing in either a float `torch.Tensor` or a `QuantTensor`:

In [25]:
input_bias_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
    input_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
input_bias_quant_conv(torch.randn(1, 2, 5, 5))

QuantTensor(value=tensor([[[[-1.1932, -0.3228, -0.4671],
          [ 0.0132,  0.0988,  0.2729],
          [-0.8529,  0.4697, -0.5951]],

         [[-0.6967,  0.4200,  0.3516],
          [ 0.1296, -0.1609, -0.0758],
          [ 1.2978,  0.0923, -0.1931]],

         [[-0.6142,  0.8017, -0.1383],
          [-0.7863,  0.1125, -0.2210],
          [ 0.3591, -0.5942, -0.1165]]]], grad_fn=<ThnnConv2DBackward>), scale=tensor([[[[4.9213e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))

In [26]:
input_bias_quant_conv(quant_tensor_input)

QuantTensor(value=tensor([[[[-5.1220e-03,  5.3110e-03, -1.7256e-03],
          [-3.5191e-03, -4.1349e-03,  1.1166e-03],
          [-2.6435e-03, -4.1546e-03,  7.3403e-03]],

         [[ 6.7945e-07, -6.7448e-03,  6.8607e-03],
          [ 2.3195e-04, -2.5737e-03, -5.3764e-03],
          [ 2.7155e-03,  5.5203e-03, -1.0607e-03]],

         [[ 4.6902e-03, -6.1470e-03,  5.9844e-03],
          [ 3.4126e-03, -6.6340e-03,  5.9087e-03],
          [-4.5733e-03,  3.8523e-03, -3.5290e-03]]]],
       grad_fn=<ThnnConv2DBackward>), scale=tensor([[[[1.6992e-07]]]], grad_fn=<MulBackward0>), zero_point=tensor(0.), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))

Notice how the output `bit_width=tensor(22.)`. This is because, in the worst case, summing a *21-bit* integer (the size of the accumulator before bias is added) and an *8-bit* integer (the size of quantized bias) gives a *22-bit* integer.
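
The 22-bit figure can be checked with a few lines of plain Python: the most negative 21-bit signed integer plus the most negative 8-bit signed integer no longer fits in 21 bits, but fits in 22. The helpers `signed_range` and `fits` are hypothetical names used only for this sketch.

```python
def signed_range(bit_width):
    """Inclusive range of a signed two's complement integer (hypothetical helper)."""
    return -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1


def fits(value, bit_width):
    """True if value is representable as a signed integer of the given width."""
    low, high = signed_range(bit_width)
    return low <= value <= high


acc_min, _ = signed_range(21)    # accumulator before the bias is added
bias_min, _ = signed_range(8)    # quantized bias
worst_sum = acc_min + bias_min   # most negative possible result

sum_bit_width = max(21, 8) + 1   # one extra bit covers the carry
print(worst_sum, fits(worst_sum, 21), fits(worst_sum, sum_bit_width))
# -1048704 False True
```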

Let's now try enabling output quantization instead of input quantization. That does not solve the problem with bias quantization, since output quantization is performed after the bias is added:

In [27]:
output_bias_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
    output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
output_bias_quant_conv(torch.randn(1, 2, 5, 5))

RuntimeError: Input scale required

Not all scenarios require bias quantization to depend on the scale factor of the input. In those cases, biases can be quantized the same way weights are quantized, and have their own scale factor. In Brevitas, a predefined quantizer that reflects this other scenario is `Int8BiasPerTensorFloatInternalScaling`. In this case then a valid quantized input is not required:

In [28]:
from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling

bias_internal_scale_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
    bias_quant=Int8BiasPerTensorFloatInternalScaling, return_quant_tensor=True)
bias_internal_scale_quant_conv(torch.randn(1, 2, 5, 5))

QuantTensor(value=tensor([[[[ 3.9541e-01,  1.0925e-01,  5.7339e-01],
          [ 2.8455e-01,  2.5993e-01, -6.4358e-01],
          [-1.5946e-02, -2.6222e-01,  1.1759e+00]],

         [[ 2.9224e-01, -7.9820e-01,  7.9893e-01],
          [-1.2683e-01, -5.2149e-01,  5.0705e-01],
          [-1.2161e+00, -3.3960e-01,  2.7555e-03]],

         [[-5.0872e-02, -2.2190e-01,  4.6538e-04],
          [-7.2407e-02,  6.3366e-04, -7.2510e-01],
          [-2.2703e-02, -7.2701e-01,  8.6453e-02]]]],
       grad_fn=<ThnnConv2DBackward>), scale=None, zero_point=None, bit_width=None, signed_t=None, training_t=tensor(True))

There are a couple of situations to be aware of concerning bias quantization that can lead to changes in the output `zero_point`.

Let's consider the scenario where we compute the convolution between a quantized input tensor and quantized weights. In the first case, we then add an *unquantized* bias on top of the output. In the second one, we add a bias quantized with its own scale factor, e.g. with the `Int8BiasPerTensorFloatInternalScaling` quantizer. In both cases, in order to make sure the output `QuantTensor` is valid (i.e. the affine quantization invariant is respected), the output `zero_point` becomes non-zero:

In [29]:
unquant_bias_input_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
    input_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
out_tensor = unquant_bias_input_quant_conv(torch.randn(1, 2, 5, 5))
out_tensor

QuantTensor(value=tensor([[[[-0.2864,  0.0194, -0.5835],
          [ 0.0089, -0.4669,  0.1854],
          [-0.6356, -0.1184, -0.4888]],

         [[-0.4935,  0.1282, -0.0513],
          [ 0.2893, -0.4362, -0.8152],
          [ 0.4719, -0.1783,  0.3152]],

         [[ 0.1789, -0.2934,  0.2483],
          [ 0.3184, -0.5840, -0.3555],
          [-0.5761, -0.1768, -0.8496]]]], grad_fn=<ThnnConv2DBackward>), scale=tensor([[[[3.5363e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor([[[[6285.0532]],

         [[ 936.8170]],

         [[4303.7090]]]], grad_fn=<DivBackward0>), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))

In [30]:
out_tensor.is_valid

True
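
A sketch of why the `zero_point` becomes non-zero, under the convention **quant_value = (integer_value - zero_point) * scale** used throughout this tutorial: adding a bias that is not a multiple of the output scale shifts the values off the grid, and a zero-point equal to `-bias / scale` absorbs that shift. The numbers are illustrative assumptions, and the snippet is not taken from the Brevitas source.

```python
out_scale = 0.125      # illustrative output scale, exact in binary floating point
acc_int = 40           # integer accumulator before the bias is added
unquant_bias = 0.3     # unquantized bias, not a multiple of out_scale

out_value = acc_int * out_scale + unquant_bias   # output in dequantized format
out_zero_point = -unquant_bias / out_scale       # absorbs the bias; generally not an integer

# The invariant value / scale + zero_point recovers the integer accumulator.
recovered_int = out_value / out_scale + out_zero_point
print(out_zero_point, round(recovered_int))  # -2.4 40

# With a zero-point of 0 the same value would not map to an integer.
print(out_value / out_scale)  # 42.4
```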

Finally, an important point about `QuantTensor`. With the exception of learned bit-width (which will be the subject of a separate tutorial) and some of the bias quantization scenarios we have just seen, returning a `QuantTensor` is usually not necessary and can create extra complexity. This is why `return_quant_tensor` currently defaults to `False`. We can easily see it in an example:

In [31]:
bias_input_quant_conv = QuantConv2d(
    in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
    input_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias)
bias_input_quant_conv(torch.randn(1, 2, 5, 5))

tensor([[[[ 1.0190,  1.2963,  0.0597],
          [-1.0870, -0.0248,  0.6649],
          [-0.3103,  0.6573, -0.3369]],

         [[ 0.6602, -0.6170, -0.6805],
          [-0.4438,  0.9469, -0.5658],
          [ 0.0269,  1.0239, -0.0896]],

         [[ 1.5136,  0.8112,  0.7560],
          [-0.4665,  0.2027, -1.1200],
          [-1.0580,  1.2795, -0.0788]]]], grad_fn=<ThnnConv2DBackward>)

Although it is not obvious, the output is actually implicitly quantized.