In [1]:
import inspect
from brevitas.nn import QuantLSTM
from IPython.display import Markdown, display

def pretty_print_source(source, language='python'):
    """Render source code in the notebook as a fenced Markdown code block.

    ``inspect.getsource`` returns a method's source with the indentation of its
    enclosing class, so the text is dedented first; otherwise the rendered code
    is shifted right. Trailing newlines are stripped so the closing fence does
    not follow an empty line.

    Args:
        source: source text to display.
        language: info string placed after the opening fence (used by the
            Markdown renderer for syntax highlighting). Defaults to ``'python'``.
    """
    import textwrap

    body = textwrap.dedent(source).rstrip('\n')
    display(Markdown(f'```{language}\n{body}\n```'))

source = inspect.getsource(QuantLSTM.__init__)
pretty_print_source(source)

  from .autonotebook import tqdm as notebook_tqdm


```python
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int = 1,
            bias: Optional[bool] = True,
            batch_first: bool = False,
            bidirectional: bool = False,
            weight_quant=Int8WeightPerTensorFloat,
            bias_quant=Int32Bias,
            io_quant=Int8ActPerTensorFloat,
            gate_acc_quant=Int8ActPerTensorFloat,
            sigmoid_quant=Uint8ActPerTensorFloat,
            tanh_quant=Int8ActPerTensorFloat,
            cell_state_quant=Int8ActPerTensorFloat,
            coupled_input_forget_gates: bool = False,
            cat_output_cell_states=True,
            shared_input_hidden_weights=False,
            shared_intra_layer_weight_quant=False,
            shared_intra_layer_gate_acc_quant=False,
            shared_cell_state_quant=True,
            return_quant_tensor: bool = False,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            **kwargs):
        super(QuantLSTM, self).__init__(
            layer_impl=_QuantLSTMLayer,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            bidirectional=bidirectional,
            weight_quant=weight_quant,
            bias_quant=bias_quant,
            io_quant=io_quant,
            gate_acc_quant=gate_acc_quant,
            sigmoid_quant=sigmoid_quant,
            tanh_quant=tanh_quant,
            cell_state_quant=cell_state_quant,
            cifg=coupled_input_forget_gates,
            shared_input_hidden_weights=shared_input_hidden_weights,
            shared_intra_layer_weight_quant=shared_intra_layer_weight_quant,
            shared_intra_layer_gate_acc_quant=shared_intra_layer_gate_acc_quant,
            shared_cell_state_quant=shared_cell_state_quant,
            return_quant_tensor=return_quant_tensor,
            dtype=dtype,
            device=device,
            **kwargs)
        if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant:
            raise RuntimeError("Concatenating cell states requires shared cell quantizers.")
        self.cat_output_cell_states = cat_output_cell_states

```

In [1]:
import time
from IPython.display import IFrame

def show_netron(model_path, port):
    """Serve a model file with Netron and embed the viewer in the notebook.

    Best effort: if ``netron`` is not installed, or starting the server fails,
    a warning explaining why is emitted and ``None`` is returned instead of an
    IFrame, so the rest of the notebook keeps running. Unlike a bare
    ``except:``, KeyboardInterrupt/SystemExit are no longer swallowed.

    Args:
        model_path: path of the model file (e.g. an exported ``.onnx``) to open.
        port: local port Netron serves on; the IFrame points at
            ``http://localhost:{port}/``.

    Returns:
        An ``IFrame`` showing the Netron UI, or ``None`` on failure.
    """
    import warnings

    try:
        import netron
    except ImportError:
        warnings.warn("netron is not installed; skipping model visualization.")
        return None

    try:
        # NOTE(review): fixed 3 s pause kept from the original; presumably it
        # gives a server previously started on the same port time to shut
        # down before restarting — confirm whether it is still needed.
        time.sleep(3.)
        netron.start(model_path, address=("localhost", port), browse=False)
    except Exception as exc:  # deliberate best effort: report, don't abort the notebook
        warnings.warn(f"Could not start netron for {model_path!r} on port {port}: {exc}")
        return None

    return IFrame(src=f"http://localhost:{port}/", width="100%", height=400)

In [2]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only = QuantLSTM(
    input_size=10,
    hidden_size=20,
    # Not a named QuantLSTM parameter: it reaches the base class via **kwargs.
    weight_bit_width=4,
    # Every non-weight quantizer is set to None (weight-only variant).
    io_quant=None,
    bias_quant=None,
    gate_acc_quant=None,
    sigmoid_quant=None,
    tanh_quant=None,
    cell_state_quant=None,
)
export_path = 'quant_lstm_weight_only_4b.onnx'
# Tracing input laid out as (seq_len, batch, input_size); batch_first defaults to False.
exported_model = export_onnx_qcdq(
    quant_lstm_weight_only,
    torch.randn(5, 1, 10),
    opset_version=14,
    export_path=export_path,
)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
show_netron(model_path=export_path, port=8080)

Stopping http://localhost:8080
Serving 'quant_lstm_weight_only_4b.onnx' at http://localhost:8080


In [6]:
import onnxruntime as ort
import numpy as np

# Pin the execution provider explicitly: since ONNX Runtime 1.9, builds that ship
# more than one execution provider (e.g. onnxruntime-gpu) raise a ValueError when
# an InferenceSession is created without a `providers` list. CPU is what a
# plain `onnxruntime` install would use anyway.
sess = ort.InferenceSession(export_path, providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32)  # (seq_len, batch_size, input_size)
# `None` requests every graph output; the result is a list of numpy arrays.
pred_onnx = sess.run(None, {input_name: np_input})

2024-06-15 16:16:35.665527411 [W:onnxruntime:, graph.cc:1296 Graph] Initializer onnx::LSTM_103 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2024-06-15 16:16:35.665546285 [W:onnxruntime:, graph.cc:1296 Graph] Initializer onnx::Concat_104 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.


In [7]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_cifg = QuantLSTM(
    input_size=10,
    hidden_size=20,
    # Passed on to the base class as `cifg=...` (coupled input/forget gates).
    coupled_input_forget_gates=True,
    # Not a named QuantLSTM parameter: it reaches the base class via **kwargs.
    weight_bit_width=4,
    # Every non-weight quantizer is set to None (weight-only variant).
    io_quant=None,
    bias_quant=None,
    gate_acc_quant=None,
    sigmoid_quant=None,
    tanh_quant=None,
    cell_state_quant=None,
)
export_path = 'quant_lstm_weight_only_cifg_4b.onnx'
# Tracing input laid out as (seq_len, batch, input_size); batch_first defaults to False.
exported_model = export_onnx_qcdq(
    quant_lstm_weight_only_cifg,
    torch.randn(5, 1, 10),
    opset_version=14,
    export_path=export_path,
)


In [8]:
show_netron(model_path=export_path, port=8082)

Serving 'quant_lstm_weight_only_cifg_4b.onnx' at http://localhost:8082


In [2]:
from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int8WeightPerTensorFloat
# Show the `tensor_quant` module resolved from the Int8ActPerTensorFloat quantizer class;
# its repr lists the integer quantizer, scaling, zero-point and bit-width submodules.
Int8ActPerTensorFloat.tensor_quant


  from .autonotebook import tqdm as notebook_tqdm


RescalingIntQuant(
  (int_quant): IntQuant(
    (float_to_int_impl): RoundSte()
    (tensor_clamp_impl): TensorClamp()
    (delay_wrapper): DelayWrapper(
      (delay_impl): _NoDelay()
    )
  )
  (scaling_impl): ParameterFromRuntimeStatsScaling(
    (stats_input_view_shape_impl): OverTensorView()
    (stats): _Stats(
      (stats_impl): AbsPercentile()
    )
    (restrict_scaling): _RestrictValue(
      (restrict_value_impl): FloatRestrictValue()
    )
    (clamp_scaling): _ClampValue(
      (clamp_min_ste): ScalarClampMinSte()
    )
    (restrict_inplace_preprocess): Identity()
    (restrict_preprocess): Identity()
  )
  (int_scaling_impl): IntScaling()
  (zero_point_impl): ZeroZeroPoint(
    (zero_point): StatelessBuffer()
  )
  (msb_clamp_bit_width_impl): BitWidthConst(
    (bit_width): StatelessBuffer()
  )
)

In [3]:
# Resolved float-to-int implementation module of the weight quantizer.
Int8WeightPerTensorFloat.float_to_int_impl

RoundSte()

In [4]:
# Resolved integer-scaling implementation module of the weight quantizer.
Int8WeightPerTensorFloat.int_scaling_impl

IntScaling()

In [5]:
# Enum-valued option from which the weight quantizer's scaling implementation is selected.
Int8WeightPerTensorFloat.scaling_impl_type

<ScalingImplType.STATS: 'STATS'>

In [6]:
# Enum-valued option behind the weight quantizer's float-to-int implementation.
Int8WeightPerTensorFloat.float_to_int_impl_type

<FloatToIntImplType.ROUND: 'ROUND'>

In [7]:
# IntQuant submodule of the weight quantizer. Compare its tensor_clamp_impl
# (TensorClampSte) with the activation quantizer's IntQuant inspected earlier (TensorClamp).
Int8WeightPerTensorFloat.int_quant

IntQuant(
  (float_to_int_impl): RoundSte()
  (tensor_clamp_impl): TensorClampSte()
  (delay_wrapper): DelayWrapper(
    (delay_impl): _NoDelay()
  )
)