# Quantized RNNs and LSTMs

With version 0.8, Brevitas introduces support for quantized recurrent layers through `QuantRNN` and `QuantLSTM`. As with other Brevitas quantized layers, `QuantRNN` and `QuantLSTM` can be used as drop-in replacements for their floating-point counterparts. They also go further and support some additional structural options for recurrent layers that are not found in upstream PyTorch. Like other quantized layers, both `QuantRNN` and `QuantLSTM` accept separate quantizers for the different tensors involved in their computation.

## QuantRNN

We start by looking at `QuantRNN`:

In [1]:
import inspect
from brevitas.nn import QuantRNN
from IPython.display import Markdown, display
import torch
torch.manual_seed(0)

def pretty_print_source(source):
    display(Markdown('```python\n' + source + '\n```'))
    
source = inspect.getsource(QuantRNN.__init__)  
pretty_print_source(source)

```python
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int = 1,
            nonlinearity: str = 'tanh',
            bias: bool = True,
            batch_first: bool = False,
            bidirectional: bool = False,
            weight_quant = Int8WeightPerTensorFloat,
            bias_quant = Int32Bias,
            io_quant = Int8ActPerTensorFloat,
            gate_acc_quant = Int8ActPerTensorFloat,
            shared_input_hidden_weights = False,
            return_quant_tensor: bool = False,
            **kwargs):
        super(QuantRNN, self).__init__(
            layer_impl=_QuantRNNLayer,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            nonlinearity=nonlinearity,
            bias=bias,
            batch_first=batch_first,
            bidirectional=bidirectional,
            weight_quant=weight_quant,
            bias_quant=bias_quant,
            io_quant=io_quant,
            gate_acc_quant=gate_acc_quant,
            shared_input_hidden_weights=shared_input_hidden_weights,
            return_quant_tensor=return_quant_tensor,
            **kwargs)

```

`QuantRNN` supports all arguments of `torch.nn.RNN` and additionally exposes four quantizers: `weight_quant` controls quantization of the weight tensor, `bias_quant` of the bias, `io_quant` of the input/output, and `gate_acc_quant` of the gate output before the nonlinearity is applied.

Compared to other layers like `QuantLinear`, a couple of things can be observed. First, input and output quantization are fused together into `io_quant`. This is because of the recurrent structure of RNN layers, where the output is fed back as input. Second, all quantizers are already set by default. This is different from a layer like `QuantLinear`, where only `weight_quant` has a default quantizer.
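
To make the placement of the four quantizers concrete, here is a minimal pure-Python sketch of a single tanh RNN step. It is not the Brevitas implementation: `fake_quant`, `no_quant`, `quant_matrix` and `rnn_step` are hypothetical helpers, every quantizer is assumed to be a symmetric 8-bit abs-max fake quantizer that computes its scale on each call (a real quantizer instance keeps its scale across calls), and the bias is left in floating point, roughly what a 32-bit bias quantizer such as `Int32Bias` amounts to. Passing `no_quant` mimics setting every quantizer to `None`.

```python
import math


def fake_quant(values, bit_width=8):
    """Symmetric fake quantization with a per-call abs-max scale (assumed scheme)."""
    max_abs = max(abs(v) for v in values)
    if max_abs == 0.0:
        return list(values)
    q_max = 2 ** (bit_width - 1) - 1  # 127 for 8 bits (narrow range)
    scale = max_abs / q_max
    return [max(-q_max, min(q_max, round(v / scale))) * scale for v in values]


def no_quant(values):
    """Stand-in for a quantizer set to None: values pass through unchanged."""
    return list(values)


def quant_matrix(matrix, quant):
    """Per-tensor quantization of a matrix: flatten, quantize, restore the rows."""
    cols = len(matrix[0])
    flat = quant([w for row in matrix for w in row])
    return [flat[i:i + cols] for i in range(0, len(flat), cols)]


def matvec(matrix, vector):
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def rnn_step(x, h, w_ih, w_hh, bias, quant=fake_quant):
    """One tanh RNN step showing where each quantizer sits (hypothetical sketch)."""
    x, h = quant(x), quant(h)  # io_quant: the fed-back state is quantized like the input
    w_ih, w_hh = quant_matrix(w_ih, quant), quant_matrix(w_hh, quant)  # weight_quant
    acc = [a + b + c for a, b, c in zip(matvec(w_ih, x), matvec(w_hh, h), bias)]
    acc = quant(acc)  # gate_acc_quant, applied before the nonlinearity
    return quant([math.tanh(a) for a in acc])  # io_quant: this output is the next input
```

With `quant=no_quant`, `rnn_step` reduces to the floating-point update `tanh(W_ih x + W_hh h + b)`; with the default fake quantizer it stays within a few quantization steps of it.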

As with `torch.nn.RNN`, `QuantRNN` defines a stack of potentially multiple layers, controlled by `num_layers`, which can be made bidirectional with `bidirectional=True`. Internally, `QuantRNN` is organized as a two-level nesting of `ModuleList`s, one level for the layer(s) and one for the direction(s):

In [2]:
def rnn_sublayer(module, sublayer_number, right_to_left_direction):
    return module.layers[sublayer_number][1 if right_to_left_direction else 0]

quant_rnn = QuantRNN(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
quant_rnn_0_left_to_right = rnn_sublayer(quant_rnn, sublayer_number=0, right_to_left_direction=False)
quant_rnn_0_right_to_left = rnn_sublayer(quant_rnn, sublayer_number=0, right_to_left_direction=True)
quant_rnn_1_left_to_right = rnn_sublayer(quant_rnn, sublayer_number=1, right_to_left_direction=False)
quant_rnn_1_right_to_left = rnn_sublayer(quant_rnn, sublayer_number=1, right_to_left_direction=True)



Setting `num_layers > 1` and/or `bidirectional=True` has different implications on different quantizers. For `weight_quant`, `gate_acc_quant` and `bias_quant`, the same quantizer *definition* is shared among different layers/directions, but each layer/direction is allocated its own instance of the quantizer. 

In [3]:
quant_rnn_0_left_to_right.gate_params.input_weight.weight_quant is quant_rnn_1_right_to_left.gate_params.input_weight.weight_quant

False

In [4]:
quant_rnn_0_left_to_right.cell.gate_acc_quant is quant_rnn_1_right_to_left.cell.gate_acc_quant

False

In [5]:
quant_rnn_0_left_to_right.gate_params.bias_quant is quant_rnn_1_right_to_left.gate_params.bias_quant

False

Conversely, for `io_quant` the same *instance* is shared among all layers and directions. This ensures that input/output tensors that are internally concatenated together share the same quantization scale, zero-point and bit width.

In [6]:
quant_rnn_0_left_to_right.io_quant is quant_rnn_1_right_to_left.io_quant

True
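
The reason for sharing the `io_quant` instance can be illustrated with a small pure-Python sketch using hypothetical values and helpers (`quantize`, `dequantize`): integers produced with one scale can be concatenated and dequantized with that same scale, whereas integers produced with two different scales cannot be described by a single scale without re-quantization.

```python
def quantize(values, scale, bit_width=8):
    """Map real values to signed integers with a given scale (narrow range assumed)."""
    q_max = 2 ** (bit_width - 1) - 1
    return [max(-q_max, min(q_max, round(v / scale))) for v in values]


def dequantize(ints, scale):
    return [q * scale for q in ints]


forward_out = [0.5, -0.25]   # hypothetical left-to-right outputs
backward_out = [0.1, 0.9]    # hypothetical right-to-left outputs

# Shared io_quant instance: one scale describes the whole concatenation.
shared_scale = 1.0 / 127
cat_ints = quantize(forward_out, shared_scale) + quantize(backward_out, shared_scale)
cat_values = dequantize(cat_ints, shared_scale)

# Separate instances: each half has its own scale, so one scale cannot describe both.
fwd_scale, bwd_scale = 0.5 / 127, 0.9 / 127
mixed_ints = quantize(forward_out, fwd_scale) + quantize(backward_out, bwd_scale)
misread = dequantize(mixed_ints, fwd_scale)  # reusing the forward scale distorts the backward half
```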

Finally, `QuantRNN` supports an additional flag, `shared_input_hidden_weights`. When `bidirectional=True`, it shares the input-to-hidden weights between the two directions, an optimization first introduced by DeepSpeech to reduce the number of parameters with minimal impact on the quality of results.

In [7]:
from brevitas.nn import QuantRNN

def count_weights(model):
    return sum(p.numel() for n, p in model.named_parameters() if 'weight' in n)

quant_rnn_single_direction = QuantRNN(input_size=10, hidden_size=20, bidirectional=False, shared_input_hidden_weights=False)
quant_rnn_bidirectional = QuantRNN(input_size=10, hidden_size=20, bidirectional=True, shared_input_hidden_weights=False)
quant_rnn_bidirectional_shared_input_hidden = QuantRNN(input_size=10, hidden_size=20, bidirectional=True, shared_input_hidden_weights=True)

print(f"Number of weights for single direction QuantRNN: {count_weights(quant_rnn_single_direction)}")
print(f"Number of weights for bidirectional QuantRNN: {count_weights(quant_rnn_bidirectional)}")
print(f"Number of weights for bidirectional QuantRNN with shared input-hidden weights: {count_weights(quant_rnn_bidirectional_shared_input_hidden)}")



Number of weights for single direction QuantRNN: 600
Number of weights for bidirectional QuantRNN: 1200
Number of weights for bidirectional QuantRNN with shared input-hidden weights: 1000
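
The counts printed above follow from a simple formula, sketched below in pure Python. `count_recurrent_weights` is a hypothetical helper, not a Brevitas function. Biases are excluded, matching `count_weights`; the multi-layer and multi-gate cases (`num_gates=4` for an LSTM) are extrapolations of the same reasoning rather than numbers taken from Brevitas.

```python
def count_recurrent_weights(input_size, hidden_size, num_layers=1, bidirectional=False,
                            shared_input_hidden_weights=False, num_gates=1):
    """Count weight elements (biases excluded) of a stacked recurrent layer.

    num_gates=1 for a plain RNN, 4 for an LSTM.
    """
    num_directions = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        # Layers after the first consume the concatenated outputs of all directions.
        layer_input = input_size if layer == 0 else hidden_size * num_directions
        input_hidden = num_gates * hidden_size * layer_input
        hidden_hidden = num_gates * hidden_size * hidden_size
        # Shared input-hidden weights are stored once per layer instead of once per direction.
        total += input_hidden * (1 if shared_input_hidden_weights else num_directions)
        total += hidden_hidden * num_directions
    return total
```

For `input_size=10, hidden_size=20`, the input-hidden matrix has 200 elements and the hidden-hidden one 400, so sharing the input-hidden weights in the bidirectional case saves exactly 200 weights.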


As with other Brevitas layers, a quantizer can be modified directly by passing keyword arguments with a matching prefix. For example, to set 4-bit per-channel weights and 6-bit I/O quantization:

In [8]:
quant_rnn_4b = QuantRNN(input_size=10, hidden_size=20, weight_bit_width=4, weight_scaling_per_output_channel=True, io_bit_width=6)
quant_rnn_4b_0_left_to_right = rnn_sublayer(quant_rnn_4b, sublayer_number=0, right_to_left_direction=False)

input_hidden_weight = quant_rnn_4b_0_left_to_right.gate_params.input_weight.quant_weight()
hidden_hidden_weight = quant_rnn_4b_0_left_to_right.gate_params.hidden_weight.quant_weight()

print(f"Input-hidden weight bit-width: {input_hidden_weight.bit_width}")
print(f"Hidden-hidden weight bit-width: {hidden_hidden_weight.bit_width}")
print(f"I/O quant bit-width: {quant_rnn_4b_0_left_to_right.io_quant.bit_width()}")
print(f"Input-hidden weight scale: {input_hidden_weight.scale}")
print(f"Hidden-hidden weight scale: {hidden_hidden_weight.scale}")

Input-hidden weight bit-width: 4.0
Hidden-hidden weight bit-width: 4.0
I/O quant bit-width: 6.0
Input-hidden weight scale: tensor([[0.0316],
        [0.0317],
        [0.0319],
        [0.0318],
        [0.0314],
        [0.0298],
        [0.0317],
        [0.0285],
        [0.0306],
        [0.0312],
        [0.0318],
        [0.0315],
        [0.0298],
        [0.0314],
        [0.0293],
        [0.0310],
        [0.0306],
        [0.0310],
        [0.0309],
        [0.0317]], grad_fn=<DivBackward0>)
Hidden-hidden weight scale: tensor([[0.0316],
        [0.0317],
        [0.0319],
        [0.0318],
        [0.0314],
        [0.0298],
        [0.0317],
        [0.0285],
        [0.0306],
        [0.0312],
        [0.0318],
        [0.0315],
        [0.0298],
        [0.0314],
        [0.0293],
        [0.0310],
        [0.0306],
        [0.0310],
        [0.0309],
        [0.0317]], grad_fn=<DivBackward0>)
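
The per-channel scales printed above can be approximated with a pure-Python sketch. It assumes an abs-max scale per output channel and a narrow signed 4-bit range of [-7, 7]; both are assumptions about the default weight quantizer, not values read from Brevitas. Assuming the weights are initialized like `torch.nn.RNN`, uniformly in (-1/sqrt(hidden_size), 1/sqrt(hidden_size)), each scale is bounded by 1/sqrt(20) / 7, about 0.0319, which is consistent with the values above. The helpers are hypothetical.

```python
import random


def per_channel_scales(weight, bit_width=4):
    """One abs-max scale per output channel (row) for a narrow signed range."""
    q_max = 2 ** (bit_width - 1) - 1  # 7 for 4 bits
    return [max(abs(w) for w in row) / q_max for row in weight]


def quantize_per_channel(weight, scales, bit_width=4):
    """Round each row with its own scale and clamp to [-q_max, q_max]."""
    q_max = 2 ** (bit_width - 1) - 1
    return [[max(-q_max, min(q_max, round(w / s))) for w in row]
            for row, s in zip(weight, scales)]


random.seed(0)
hidden_size, input_size = 20, 10
bound = 1 / hidden_size ** 0.5  # torch.nn.RNN init range: U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))
weight = [[random.uniform(-bound, bound) for _ in range(input_size)] for _ in range(hidden_size)]

scales = per_channel_scales(weight)  # one scale per row, like the (20, 1) tensors above
int_weight = quantize_per_channel(weight, scales)
```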


`QuantRNN` follows the same `forward` interface as `torch.nn.RNN`, with a couple of exceptions: packed variable-length inputs and unbatched inputs are currently not supported.
Otherwise, the interface is the same.

Inputs are expected to have shape `(sequence, batch, features)` for `batch_first=False`, or `(batch, sequence, features)` for `batch_first=True`. The layer returns a tuple `(outputs, hidden_states)`. `outputs` has shape `(sequence, batch, hidden_size * num_directions)` for `batch_first=False`, or `(batch, sequence, hidden_size * num_directions)` for `batch_first=True`, where `num_directions=2` when `bidirectional=True` and 1 otherwise. `hidden_states` has shape `(num_directions * num_layers, batch, hidden_size)` regardless of `batch_first`.
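
These shape rules can be summarized in a small pure-Python helper. `rnn_output_shapes` is hypothetical and only mirrors the `torch.nn.RNN` conventions described above; its first check matches the example that follows.

```python
def rnn_output_shapes(input_shape, hidden_size, num_layers=1, bidirectional=False, batch_first=False):
    """Return (outputs_shape, hidden_states_shape) following the torch.nn.RNN conventions."""
    if batch_first:
        batch, seq_len, _ = input_shape
    else:
        seq_len, batch, _ = input_shape
    num_directions = 2 if bidirectional else 1
    features = hidden_size * num_directions
    outputs = (batch, seq_len, features) if batch_first else (seq_len, batch, features)
    # The hidden state layout does not depend on batch_first.
    hidden_states = (num_directions * num_layers, batch, hidden_size)
    return outputs, hidden_states
```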

In [9]:
import torch
from brevitas.nn import QuantRNN

quant_rnn = QuantRNN(input_size=10, hidden_size=20, batch_first=True)
outputs, hidden_states = quant_rnn(torch.randn(2, 5, 10))
print(f"Output size: {outputs.shape}")
print(f"Hidden states size: {hidden_states.shape}")

Output size: torch.Size([2, 5, 20])
Hidden states size: torch.Size([1, 2, 20])


As with other quantized layers, it's possible to return a `QuantTensor` with `return_quant_tensor=True`. As a reminder, a `QuantTensor` is just a data structure that captures the quantization metadata associated with a quantized tensor:

In [10]:
import torch
from brevitas.nn import QuantRNN

quant_rnn = QuantRNN(input_size=10, hidden_size=20, batch_first=True, return_quant_tensor=True)
quant_rnn(torch.randn(2, 5, 10))



(QuantTensor(value=tensor([[[-0.4458, -0.1651, -0.7045, -0.5889, -0.2532, -0.0330, -0.1651,
            0.1706,  0.1376,  0.4348,  0.5834, -0.3577, -0.2807,  0.1046,
            0.2532,  0.2807,  0.2532, -0.4293,  0.1376, -0.1486],
          [-0.1569,  0.3530, -0.6995, -0.0458, -0.5295, -0.3007, -0.7257,
            0.2877, -0.1308,  0.6603,  0.0196, -0.8237,  0.0065, -0.4380,
           -0.2615,  0.3138, -0.0850,  0.0065,  0.0458, -0.1961],
          [ 0.1929, -0.5981, -0.2508, -0.2251, -0.5917,  0.2251,  0.0257,
            0.2508, -0.3023,  0.2830,  0.3344, -0.4309, -0.0836,  0.2701,
            0.3666, -0.1351,  0.1736, -0.0257,  0.1286, -0.6174],
          [ 0.4682, -0.1804,  0.2780,  0.4974,  0.4389, -0.0585, -0.6242,
           -0.0098,  0.2341,  0.3511, -0.2926, -0.4925,  0.1414, -0.4633,
           -0.0683,  0.2633,  0.3804,  0.3024,  0.1951,  0.1707],
          [-0.0852,  0.0965, -0.4656, -0.3180, -0.3464, -0.2782, -0.1931,
           -0.6360, -0.3180, -0.3293,  0.7211,  0.43

Similarly, a `QuantTensor` can be passed in as input. However, whenever `io_quant` is set (which it is by default), the input is re-quantized:

In [11]:
from brevitas.nn import QuantIdentity

quant_identity = QuantIdentity(return_quant_tensor=True)
quant_rnn(quant_identity(torch.randn(2, 5, 10)))

(QuantTensor(value=tensor([[[ 0.1760,  0.2670, -0.1214, -0.3702,  0.3884,  0.4127,  0.0243,
            0.0425, -0.2246, -0.0910, -0.2670,  0.4734,  0.0971, -0.3824,
            0.1396,  0.6858,  0.0061,  0.3702,  0.1275,  0.5037],
          [ 0.2831,  0.0566, -0.2831, -0.2661, -0.0793,  0.3511, -0.4926,
            0.0510, -0.6455,  0.7191, -0.1812, -0.6172,  0.1529,  0.4077,
           -0.7078, -0.0453, -0.0963,  0.4926, -0.4983, -0.4077],
          [ 0.0000, -0.3977,  0.0947,  0.1894, -0.3725, -0.2589, -0.3914,
            0.3409, -0.0063,  0.2652, -0.5177, -0.4230, -0.0821, -0.0631,
            0.0505, -0.0189,  0.0253, -0.1578, -0.4988,  0.5556],
          [ 0.4809,  0.8144, -0.6925,  0.4360,  0.0256, -0.4360, -0.5130,
            0.2501, -0.1347,  0.7631, -0.5386, -0.2437,  0.4296, -0.1988,
           -0.7246, -0.1154, -0.2437,  0.3655,  0.0641,  0.3142],
          [ 0.0706, -0.0192, -0.7185, -0.8211, -0.5709,  0.1155,  0.4683,
            0.3400, -0.3015,  0.3528,  0.3143, -0.11

As with `torch.nn.RNN`, by default the initial hidden state is initialized to 0, but a custom hidden state of shape `(num_directions * num_layers, batch, hidden_size)` can be passed in:

In [12]:
quant_rnn(torch.randn(2, 5, 10), torch.randn(1, 2, 20))

(QuantTensor(value=tensor([[[-0.1984,  0.2499, -0.1102,  0.2499, -0.0955, -0.4630, -0.8672,
            0.1911, -0.4851,  0.8085,  0.6982, -0.5806,  0.0000, -0.4189,
           -0.7423, -0.4851, -0.9260, -0.0147,  0.0514, -0.1984],
          [-0.2167,  0.5092, -0.3846,  0.0650,  0.6717, -0.2492, -0.0867,
            0.3142, -0.3900,  0.3521,  0.4767, -0.1137,  0.6879,  0.1733,
           -0.0596,  0.4279, -0.5471, -0.2762,  0.5904, -0.3737],
          [-0.1335, -0.0140, -0.2810, -0.5339, -0.5339,  0.0562,  0.7236,
           -0.1264, -0.0211, -0.3021, -0.1124,  0.4777,  0.3793,  0.2388,
           -0.0702,  0.4847, -0.4988,  0.7236,  0.5901, -0.4847],
          [ 0.3340, -0.5225, -0.1242,  0.1499,  0.3083, -0.1756, -0.1713,
            0.0000,  0.3512, -0.3041,  0.3126, -0.5482,  0.4882,  0.1028,
           -0.4796,  0.1028, -0.2527, -0.3640,  0.1713,  0.0471],
          [-0.4438, -0.2686, -0.3095, -0.2978, -0.0993,  0.0584,  0.4846,
           -0.0526,  0.3737, -0.4496,  0.1109,  0.74

As with other Brevitas layers, `QuantRNN` can be initialized from a pretrained floating-point `torch.nn.RNN`. For the purpose of this tutorial, we simulate this with an untrained `torch.nn.RNN`. As with other quantized layers, setting `brevitas.config.IGNORE_MISSING_KEYS` might be necessary, depending on which quantizers are set. With the default quantizers, missing activation scale keys would trigger an error, so we set it to `True`:

In [13]:
from torch.nn import RNN
from brevitas.nn import QuantRNN
from brevitas import config

config.IGNORE_MISSING_KEYS = True

float_rnn = RNN(input_size=10, hidden_size=20)
quant_rnn = QuantRNN(input_size=10, hidden_size=20)
quant_rnn.load_state_dict(float_rnn.state_dict())

<All keys matched successfully>

Similar to other quantized layers, quantization on a certain tensor can be disabled by setting a quantizer to `None`. Setting all quantizers to `None` recovers the same behaviour as the floating-point variant:

In [14]:
from torch.nn import RNN
from brevitas.nn import QuantRNN
from brevitas import config
ATOL = 1e-6

config.IGNORE_MISSING_KEYS = True
torch.manual_seed(123456)

float_rnn = RNN(input_size=10, hidden_size=20)
quant_rnn = QuantRNN(input_size=10, hidden_size=20, weight_quant=None, io_quant=None, gate_acc_quant=None, bias_quant=None)

# Set both layers to the same state_dict
quant_rnn.load_state_dict(float_rnn.state_dict())

# Generate random input
inp = torch.randn(5, 2, 10)
# Check outputs are the same
assert torch.allclose(quant_rnn(inp)[0], float_rnn(inp)[0], atol=ATOL)
# Check hidden states are the same
assert torch.allclose(quant_rnn(inp)[1], float_rnn(inp)[1], atol=ATOL)

As with other quantized layers, we can leverage other prebuilt quantizers too. For example, to perform binary weight quantization:

In [15]:
from brevitas.quant.binary import SignedBinaryWeightPerTensorConst

binary_rnn = QuantRNN(input_size=10, hidden_size=20, weight_quant=SignedBinaryWeightPerTensorConst)
binary_rnn(torch.randn(5, 2, 10))

(tensor([[[-0.3684, -0.0946, -0.4480,  0.0050,  0.1543,  0.6322,  0.1643,
            0.1693,  0.2937,  0.5227,  0.2290, -0.3534, -0.3883,  0.4331,
            0.0000,  0.1693, -0.4331,  0.3634, -0.0050,  0.1941],
          [-0.2240, -0.0199, -0.3534,  0.0946,  0.3485,  0.3534,  0.1941,
            0.1643,  0.1145,  0.4082,  0.2987, -0.0647, -0.0946,  0.1543,
            0.1145, -0.0498,  0.0647,  0.1493,  0.0299, -0.1195]],
 
         [[ 0.0776, -0.0776, -0.5670,  0.4178, -0.0239,  0.4476,  0.2029,
           -0.0836,  0.3521,  0.7042,  0.6326,  0.4058, -0.4118, -0.0477,
           -0.2387, -0.0179, -0.4416, -0.4237, -0.3282, -0.1074],
          [-0.2626,  0.3581,  0.2328, -0.2268, -0.2686, -0.3103,  0.4536,
            0.3461,  0.3103,  0.3163,  0.3282, -0.3163, -0.7639,  0.0179,
            0.0060,  0.0776, -0.5849, -0.5252,  0.1790,  0.2984]],
 
         [[-0.5411,  0.3147,  0.6184, -0.3037, -0.1877, -0.3755,  0.1767,
           -0.1767, -0.1491, -0.1049,  0.2871, -0.0552, -0.0883,
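
Conceptually, binary weight quantization keeps only the sign of each weight and multiplies it by a constant scale, so every weight maps to `+scale` or `-scale`. In the minimal sketch below, `binarize` is hypothetical, and both the scale value and mapping zero to `+scale` are assumptions for illustration, not the constants used by `SignedBinaryWeightPerTensorConst`.

```python
def binarize(weights, scale):
    """Binary weight quantization: sign of each weight times a constant scale.

    Mapping 0.0 to +scale is an assumed tie-breaking rule.
    """
    return [scale if w >= 0 else -scale for w in weights]


# Hypothetical weights and scale, chosen only for illustration.
binary_weights = binarize([0.3, -0.2, 0.0, -0.7], scale=0.5)
```

Each weight then needs a single bit of storage, plus one scale per tensor.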

## QuantLSTM

We now look at `QuantLSTM`:

In [16]:
import inspect
from brevitas.nn import QuantLSTM
from IPython.display import Markdown, display

def pretty_print_source(source):
    display(Markdown('```python\n' + source + '\n```'))
    
source = inspect.getsource(QuantLSTM.__init__)  
pretty_print_source(source)

```python
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int = 1,
            bias: bool = True,
            batch_first: bool = False,
            bidirectional: bool = False,
            weight_quant = Int8WeightPerTensorFloat,
            bias_quant = Int32Bias,
            io_quant = Int8ActPerTensorFloat,
            gate_acc_quant = Int8ActPerTensorFloat,
            sigmoid_quant = Uint8ActPerTensorFloat,
            tanh_quant = Int8ActPerTensorFloat,
            cell_state_quant = Int8ActPerTensorFloat,
            coupled_input_forget_gates: bool = False,
            cat_output_cell_states = True,
            shared_input_hidden_weights = False,
            shared_intra_layer_weight_quant = False,
            shared_intra_layer_gate_acc_quant = False,
            shared_cell_state_quant = True,
            return_quant_tensor: bool = False,
            **kwargs):
        super(QuantLSTM, self).__init__(
            layer_impl=_QuantLSTMLayer,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            bidirectional=bidirectional,
            weight_quant=weight_quant,
            bias_quant=bias_quant,
            io_quant=io_quant,
            gate_acc_quant=gate_acc_quant,
            sigmoid_quant=sigmoid_quant,
            tanh_quant=tanh_quant,
            cell_state_quant=cell_state_quant,
            cifg=coupled_input_forget_gates,
            shared_input_hidden_weights=shared_input_hidden_weights,
            shared_intra_layer_weight_quant=shared_intra_layer_weight_quant,
            shared_intra_layer_gate_acc_quant=shared_intra_layer_gate_acc_quant,
            shared_cell_state_quant=shared_cell_state_quant,
            return_quant_tensor=return_quant_tensor,
            **kwargs)
        if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant:
            raise RuntimeError("Concatenating cell states requires shared cell quantizers.")
        self.cat_output_cell_states = cat_output_cell_states

```

As with `QuantRNN`, `QuantLSTM` supports all options of `torch.nn.LSTM`. Everything said so far about `QuantRNN` applies to `QuantLSTM` too, but there are a few more things to be aware of.

`QuantLSTM` accepts a few more quantizers: `sigmoid_quant`, `tanh_quant` and `cell_state_quant`. As with `QuantRNN`, setting `bidirectional=True` and/or `num_layers > 1` shares the instance of some quantizers but not others. In particular, `io_quant` is shared among all layers and directions, as was the case for `QuantRNN`. `cell_state_quant` is shared by default, but sharing can be disabled by setting `shared_cell_state_quant=False`. However, that also requires setting `cat_output_cell_states=False`, as otherwise the output would be a concatenation of cell states quantized with different quantizers, which Brevitas rejects with a `RuntimeError`.

LSTMs have four gates, each with its own input-hidden and hidden-hidden weights. Brevitas takes in a single `weight_quant` definition, but allocates four separate instances of the weight quantizer, so each gate is quantized independently and can have its own scale and zero-point. To force sharing the same weight quantizer across all gates, `QuantLSTM` supports setting `shared_intra_layer_weight_quant=True`. The same reasoning applies to the quantization of the output of each gate before the activation functions, which is controlled by `gate_acc_quant`: to force a single quantizer instance to be shared, set `shared_intra_layer_gate_acc_quant=True`. The sigmoid and tanh activations, instead, are always allocated separate quantizer instances.
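
The effect of `shared_intra_layer_weight_quant` on scales can be sketched in pure Python. The gate weights below are hypothetical and the abs-max 8-bit scaling is an assumption, not necessarily what the default quantizer does: with one instance per gate each gate gets its own scale, while a single shared instance must cover the largest weight of any gate.

```python
def abs_max_scale(values, bit_width=8):
    """Abs-max scale for a narrow signed range (assumed scaling scheme)."""
    return max(abs(v) for v in values) / (2 ** (bit_width - 1) - 1)


# Hypothetical, flattened weights for the four LSTM gates of one layer/direction.
gate_weights = {
    "input": [0.10, -0.20, 0.05],
    "forget": [0.40, -0.10, 0.30],
    "cell": [-0.60, 0.20, 0.10],
    "output": [0.05, 0.02, -0.01],
}

# Default: one quantizer instance per gate, so each gate gets its own scale.
per_gate = {gate: abs_max_scale(w) for gate, w in gate_weights.items()}

# shared_intra_layer_weight_quant=True: a single instance, so one scale covers all gates.
shared = abs_max_scale([w for ws in gate_weights.values() for w in ws])
```

With the shared scale, the output gate's largest weight (0.05) lands on integer level 11 instead of 127, so gates with small weights lose resolution in exchange for a single scale per layer and direction.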

Finally, `QuantLSTM` also supports coupled input-forget gates (CIFG), where the forget gate is defined as `forget_gate = 1 - input_gate`, by setting `coupled_input_forget_gates=True`. This optimization reduces compute and the number of parameters, and is orthogonal to all other settings, such as `shared_input_hidden_weights`.
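
The CIFG variant can be sketched as one step of a single-unit floating-point LSTM in pure Python. This is a hypothetical illustration of the standard LSTM equations, not Brevitas code, and it omits quantization. With `cifg=True`, no forget-gate parameters are needed because the forget gate is derived from the input gate.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h, c, params, cifg=False):
    """One step of a 1-unit LSTM; params[gate] = (w_input, w_hidden, bias)."""
    def pre_activation(gate):
        w_x, w_h, b = params[gate]
        return w_x * x + w_h * h + b

    i = sigmoid(pre_activation("input"))
    # CIFG: the forget gate has no parameters of its own and is tied to the input gate.
    f = 1.0 - i if cifg else sigmoid(pre_activation("forget"))
    g = math.tanh(pre_activation("cell"))
    o = sigmoid(pre_activation("output"))
    c_new = f * c + i * g
    h_new = o * math.tanh(c_new)
    return h_new, c_new, (i, f)


# Hypothetical parameters: no "forget" entry is needed when cifg=True.
params = {"input": (0.5, -0.3, 0.1), "cell": (0.8, 0.2, 0.0), "output": (-0.4, 0.6, 0.2)}
h, c, (i, f) = lstm_step(1.0, 0.5, -0.2, params, cifg=True)
```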

## Just-in-time compilation

Custom recurrent layers can be quite slow at training time, and adding quantization only makes it worse. To mitigate the issue, both `QuantRNN` and `QuantLSTM` support JIT compilation. Setting the environment variable `BREVITAS_JIT=1` triggers end-to-end compilation of the quantized recurrent cell through the PyTorch TorchScript compiler.
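
When launching from Python rather than from the shell, the variable can be set through `os.environ`. As far as we know, Brevitas reads `BREVITAS_JIT` when it is first imported, so the assignment should come before any `brevitas` import; treat this ordering as an assumption to verify against your Brevitas version.

```python
import os

# Set the flag before brevitas is imported for the first time (assumed to be read at import).
os.environ["BREVITAS_JIT"] = "1"

# from brevitas.nn import QuantLSTM  # import only after the variable has been set
```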

## Calibration

As of version 0.8, `QuantRNN` and `QuantLSTM` support neither calibration of quantized activations through `calibration_mode` nor bias correction through `bias_correction_mode`. Support will be added in a future version.

## Export

As of Brevitas 0.8, export of quantized recurrent layers is still a work in progress. As a proof of concept, there is partial support only for exporting `QuantLSTM` to ONNX QCDQ, a way to represent quantization in ONNX using only standard ops (QuantizeLinear->Clip->DequantizeLinear), and to QONNX, a custom set of quantized operators introduced by Brevitas on top of ONNX. Two use cases are supported: (1) only `weight_quant` is set, supported by both QCDQ and QONNX, and (2) all quantizers are set, supported only by QONNX. In both cases, `bidirectional=True` and `num_layers > 1` are supported. We first define a utility function to visualize the network with Netron, which requires `pip install netron`.

In [17]:
import time
from IPython.display import IFrame

def show_netron(model_path, port):
    try:
        import netron
        time.sleep(3.)
        netron.start(model_path, address=("localhost", port), browse=False)
        return IFrame(src=f"http://localhost:{port}/", width="100%", height=400)
    except Exception:
        # netron is missing or the server failed to start: skip the visualization
        pass

### QuantLSTM weight-only quantization export

For use case (1), we export to ONNX QCDQ. Weight quantization is represented with QCDQ nodes, while the standard ONNX `LSTM` operator is used for the recurrent cell. With this approach, any weight bit width >= 2 can be represented. Opset 14 is required. For this single-layer, single-direction example, we keep the default `weight_quant`, set `weight_bit_width=4`, and disable all other quantizers:
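
The QCDQ pattern can be mimicked in pure Python to show how a sub-8-bit weight fits into standard ONNX ops: `QuantizeLinear` rounds half to even and saturates to int8, `Clip` narrows the integers to the target bit width, and `DequantizeLinear` maps them back to real values. The helpers below are hypothetical re-implementations of the operator semantics; clipping to the narrow range [-7, 7] for 4 bits is an assumption about the default Brevitas weight quantizer.

```python
def quantize_linear(x, scale, zero_point=0):
    """Like ONNX QuantizeLinear to int8: round half to even, add zero-point, saturate."""
    return [max(-128, min(127, round(v / scale) + zero_point)) for v in x]


def clip(q, bit_width):
    """Like the Clip node: restrict int8 values to a narrow signed `bit_width` range."""
    q_max = 2 ** (bit_width - 1) - 1
    return [max(-q_max, min(q_max, v)) for v in q]


def dequantize_linear(q, scale, zero_point=0):
    """Like ONNX DequantizeLinear: (q - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]


def qcdq(x, scale, bit_width=4):
    """QuantizeLinear -> Clip -> DequantizeLinear on a list of weights."""
    return dequantize_linear(clip(quantize_linear(x, scale), bit_width), scale)


weights = [0.1, -0.5, 0.03, 1.0]  # hypothetical weights
fake_quantized = qcdq(weights, scale=0.05, bit_width=4)
```

Python's built-in `round` also rounds half to even, which is why it stands in for the rounding of `QuantizeLinear` here.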

In [18]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only = QuantLSTM(input_size=10, hidden_size=20, weight_bit_width=4, io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_4b.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)


In [19]:
show_netron(export_path, 8080)

Serving 'quant_lstm_weight_only_4b.onnx' at http://localhost:8080


Note that the model can then be run with `onnxruntime`:

In [19]:
import onnxruntime as ort
import numpy as np

sess = ort.InferenceSession(export_path)
input_name = sess.get_inputs()[0].name
np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32)  # (seq_len, batch_size, input_size)
pred_onnx = sess.run(None, {input_name: np_input})

CIFG is also supported in a way that follows the semantics of `onnxruntime`:

In [21]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_cifg = QuantLSTM(
    input_size=10, hidden_size=20, coupled_input_forget_gates=True, weight_bit_width=4,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_cifg_4b.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_cifg, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)

In [22]:
show_netron(export_path, 8082)

Serving 'quant_lstm_weight_only_cifg_4b.onnx' at http://localhost:8082


As before we can run it with `onnxruntime`:

In [21]:
import onnxruntime as ort
import numpy as np

sess = ort.InferenceSession(export_path)
input_name = sess.get_inputs()[0].name
np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32)  # (seq_len, batch_size, input_size)
pred_onnx = sess.run(None, {input_name: np_input})

For the two-layer, bidirectional case:

In [24]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)



In [25]:
show_netron(export_path, 8083)

Serving 'quant_lstm_weight_only_bidirectional_2_layers.onnx' at http://localhost:8083


Shared input-hidden weights are also supported:

In [26]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_bidirectional_2_layers_shared = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True, num_layers=2, shared_input_hidden_weights=True, weight_bit_width=4,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_ih.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers_shared, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)



In [27]:
show_netron(export_path, 8085)

Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_ih.onnx' at http://localhost:8085


We can observe how setting `shared_intra_layer_weight_quant=True` affects the network. Now, for each layer and for each direction within a layer, all weight quantizers share the same scale, zero-point and bit width:

In [24]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, shared_intra_layer_weight_quant=True,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)



In [None]:
show_netron(export_path, 8086)

Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx' at http://localhost:8086

Alternatively, if we set both `shared_input_hidden_weights=True` and `shared_intra_layer_weight_quant=True`, a side effect is that all weight quantizers across both directions of a given layer share the same scale, zero-point and bit width.

In [25]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, 
    shared_input_hidden_weights=True, shared_intra_layer_weight_quant=True,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)



In [None]:
show_netron(export_path, 8087)

Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx' at http://localhost:8087

### QuantLSTM full quantization export

For use case (2), we export to QONNX. Weight quantization is represented with `Quant` nodes, while a custom quantized LSTM operator, `QuantLSTMCell`, is generated for the recurrent cell. Note that `QuantLSTMCell` is not yet supported for execution in the `qonnx` library. In a future version of Brevitas, `QuantLSTMCell` will instead be lowered to a series of standard ops plus `Quant` nodes. For the purpose of this example, we keep all quantizers at their defaults:

In [26]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_qonnx

quant_lstm = QuantLSTM(input_size=10, hidden_size=20)
export_path = 'quant_lstm.onnx'
exported_model = export_qonnx(quant_lstm, (torch.randn(5, 1, 10)), export_path=export_path)

In [None]:
show_netron(export_path, 8088)

Serving 'quant_lstm.onnx' at http://localhost:8088

`QuantLSTMCell` takes the following inputs, in order:

- `quant_input`
- `quant_hidden_state`
- `quant_cell_state`
- `quant_weight_ii`
- `quant_weight_if`
- `quant_weight_ic`
- `quant_weight_io`
- `quant_weight_hi`
- `quant_weight_hf`
- `quant_weight_hc`
- `quant_weight_ho`
- `quant_bias_input`
- `quant_bias_forget`
- `quant_bias_cell`
- `quant_bias_output`
- `output_scale`
- `output_zero_point`
- `output_bit_width`
- `cell_state_scale`
- `cell_state_zero_point`
- `cell_state_bit_width`
- `input_acc_scale`
- `input_acc_zero_point`
- `input_acc_bit_width`
- `forget_acc_scale`
- `forget_acc_zero_point`
- `forget_acc_bit_width`
- `cell_acc_scale`
- `cell_acc_zero_point`
- `cell_acc_bit_width`
- `output_acc_scale`
- `output_acc_zero_point`
- `output_acc_bit_width`
- `input_sigmoid_scale`
- `input_sigmoid_zero_point`
- `input_sigmoid_bit_width`
- `forget_sigmoid_scale`
- `forget_sigmoid_zero_point`
- `forget_sigmoid_bit_width`
- `cell_tanh_scale`
- `cell_tanh_zero_point`
- `cell_tanh_bit_width`
- `output_sigmoid_scale`
- `output_sigmoid_zero_point`
- `output_sigmoid_bit_width`
- `hidden_state_tanh_scale`
- `hidden_state_tanh_zero_point`
- `hidden_state_tanh_bit_width`
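
The list above has a regular structure: three state tensors, eight weights (four input-hidden, then four hidden-hidden, in input/forget/cell/output gate order), four biases, and then a `(scale, zero_point, bit_width)` triplet for each of eleven quantized intermediate tensors. The hypothetical helper below rebuilds the names from that structure, which may be handy when writing a consumer for the exported graph; it is derived only from the list above, not from the Brevitas exporter.

```python
def quant_lstm_cell_input_names():
    """Rebuild the QuantLSTMCell input order listed above from its building blocks."""
    gates = ["input", "forget", "cell", "output"]
    names = ["quant_input", "quant_hidden_state", "quant_cell_state"]
    names += [f"quant_weight_i{g[0]}" for g in gates]  # input-hidden weights
    names += [f"quant_weight_h{g[0]}" for g in gates]  # hidden-hidden weights
    names += [f"quant_bias_{g}" for g in gates]
    # Every remaining quantized tensor contributes a (scale, zero_point, bit_width) triplet.
    tensors = ["output", "cell_state"]
    tensors += [f"{g}_acc" for g in gates]  # presumably the gate_acc_quant outputs
    tensors += ["input_sigmoid", "forget_sigmoid", "cell_tanh",
                "output_sigmoid", "hidden_state_tanh"]
    for t in tensors:
        names += [f"{t}_scale", f"{t}_zero_point", f"{t}_bit_width"]
    return names
```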

All previous use cases illustrated for the weight-only quantization scenario are also supported.