# INT8 Quantization of **LSTMNetVIT** and ONNX export (opset 17)

This notebook walks through:

1. Re‑creating the `LSTMNetVIT` architecture in PyTorch and loading pretrained weights  
2. Exporting the model to **ONNX** with opset 17  
3. Quantising the exported model to **INT8** with **ONNX Runtime**  
4. Running a quick functional check to compare FP32 vs. INT8 outputs  

> **Replace** the placeholder paths (`PATH_TO_WEIGHTS`) and (optionally) supply your own calibration data or dataset loader if you want to run static quantisation. The default path below applies **dynamic** INT8 quantisation which does not need calibration.


## 0. Install / Upgrade dependencies

In [16]:
# If you are in a fresh environment, uncomment:
# !pip install --upgrade --quiet torch torchvision torchaudio onnx onnxruntime onnxruntime-tools

## 1. Imports & helper definitions

In [17]:
import torch, torch.nn as nn
from torch.nn.utils import spectral_norm
import onnx, onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType, quantize_static
from pathlib import Path
print('PyTorch', torch.__version__)
print('ONNX', onnx.__version__)
print('ONNX Runtime', ort.__version__)

PyTorch 2.6.0+cpu
ONNX 1.17.0
ONNX Runtime 1.19.2


### 1.1 Model architecture

In [18]:
# ----  Helper for original forward() ----
def refine_inputs(X):              # identity placeholder; keep if original util not available
    return X

# ---- Component modules (copied from original script) ----
class OverlapPatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size, stride, padding):
        super().__init__()
        self.cn1 = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size,
                             stride=stride, padding=padding)
        self.layerNorm = nn.LayerNorm(out_channels)

    def forward(self, patches):
        x = self.cn1(patches)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.layerNorm(x)
        return x, H, W

class EfficientSelfAttention(nn.Module):
    def __init__(self, channels, reduction_ratio, num_heads):
        super().__init__()
        assert channels % num_heads == 0
        self.heads = num_heads
        self.cn1 = nn.Conv2d(channels, channels, kernel_size=reduction_ratio,
                             stride=reduction_ratio)
        self.ln1 = nn.LayerNorm(channels)
        self.keyValueExtractor = nn.Linear(channels, channels * 2)
        self.query = nn.Linear(channels, channels)
        self.smax = nn.Softmax(dim=-1)
        self.finalLayer = nn.Linear(channels, channels)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x1 = x.permute(0, 2, 1).reshape(B, C, H, W)
        x1 = self.cn1(x1).reshape(B, C, -1).permute(0, 2, 1)
        x1 = self.ln1(x1)
        k, v = self.keyValueExtractor(x1).reshape(B, -1, 2, self.heads, C // self.heads)                                           .permute(2, 0, 3, 1, 4)
        q = self.query(x).reshape(B, N, self.heads, C // self.heads)                         .permute(0, 2, 1, 3)
        dimHead = (C / self.heads) ** 0.5
        attn = self.smax(torch.matmul(q, k.transpose(-2, -1)) / dimHead)
        attn = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        return self.finalLayer(attn)

class MixFFN(nn.Module):
    def __init__(self, channels, expansion_factor):
        super().__init__()
        expanded = channels * expansion_factor
        self.mlp1 = nn.Linear(channels, expanded)
        self.depthwise = nn.Conv2d(expanded, expanded, kernel_size=3,
                                   padding='same', groups=channels)
        self.gelu = nn.GELU()
        self.mlp2 = nn.Linear(expanded, channels)

    def forward(self, x, H, W):
        x = self.mlp1(x)
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.gelu(self.depthwise(x).flatten(2).transpose(1, 2))
        return self.mlp2(x)

class MixTransformerEncoderLayer(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size, stride, padding,
                 n_layers, reduction_ratio, num_heads, expansion_factor):
        super().__init__()
        self.patchMerge = OverlapPatchMerging(in_channels, out_channels,
                                              patch_size, stride, padding)
        self._attn = nn.ModuleList(
            [EfficientSelfAttention(out_channels, reduction_ratio, num_heads)
             for _ in range(n_layers)])
        self._ffn = nn.ModuleList(
            [MixFFN(out_channels, expansion_factor) for _ in range(n_layers)])
        self._lNorm = nn.ModuleList([nn.LayerNorm(out_channels)
                                     for _ in range(n_layers)])

    def forward(self, x):
        B, C, H, W = x.shape
        x, H, W = self.patchMerge(x)
        for attn, ffn, ln in zip(self._attn, self._ffn, self._lNorm):
            x = ln(x + attn(x, H, W) + ffn(x, H, W))
        return x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

# ---- Full model ----
class LSTMNetVIT(nn.Module):
    """ViT + LSTM network (~3.56 M params)"""
    def __init__(self):
        super().__init__()
        self.encoder_blocks = nn.ModuleList([
            MixTransformerEncoderLayer(1, 32, patch_size=7, stride=4, padding=3,
                                       n_layers=2, reduction_ratio=8, num_heads=1, expansion_factor=8),
            MixTransformerEncoderLayer(32, 64, patch_size=3, stride=2, padding=1,
                                       n_layers=2, reduction_ratio=4, num_heads=2, expansion_factor=8)
        ])
        self.decoder = spectral_norm(nn.Linear(4608, 512))
        self.lstm = nn.LSTM(input_size=517, hidden_size=128, num_layers=3,
                            dropout=0.1)
        self.nn_fc2 = spectral_norm(nn.Linear(128, 3))
        self.up_sample = nn.Upsample(size=(16, 24), mode='bilinear',
                                     align_corners=True)
        self.pxShuffle = nn.PixelShuffle(upscale_factor=2)
        self.down_sample = nn.Conv2d(48, 12, 3, padding=1)

    def forward(self, X):
        X = refine_inputs(X)
        x = X[0]
        embeds = [x]
        for block in self.encoder_blocks:
            embeds.append(block(embeds[-1]))
        out = torch.cat([self.pxShuffle(embeds[2]),
                         self.up_sample(embeds[1])], dim=1)
        out = self.down_sample(out)
        out = self.decoder(out.flatten(1))
        out = torch.cat([out, X[1] / 10.0, X[2]], dim=1).float()
        if len(X) > 3:
            out, h = self.lstm(out, X[3])
        else:
            out, h = self.lstm(out)
        out = self.nn_fc2(out)
        return out, h


### 1.2 Load pretrained weights

In [19]:
model = LSTMNetVIT()
weights_path = Path('ViTLSTM_model.pth')  # <-- change me
if weights_path.exists():
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
else:
    print('‼️ Pre‑trained weights not found – using random init')
model.eval()

LSTMNetVIT(
  (encoder_blocks): ModuleList(
    (0): MixTransformerEncoderLayer(
      (patchMerge): OverlapPatchMerging(
        (cn1): Conv2d(1, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
        (layerNorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      )
      (_attn): ModuleList(
        (0-1): 2 x EfficientSelfAttention(
          (cn1): Conv2d(32, 32, kernel_size=(8, 8), stride=(8, 8))
          (ln1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (keyValueExtractor): Linear(in_features=32, out_features=64, bias=True)
          (query): Linear(in_features=32, out_features=32, bias=True)
          (smax): Softmax(dim=-1)
          (finalLayer): Linear(in_features=32, out_features=32, bias=True)
        )
      )
      (_ffn): ModuleList(
        (0-1): 2 x MixFFN(
          (mlp1): Linear(in_features=32, out_features=256, bias=True)
          (depthwise): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=32)
       

### 1.3 Create a dummy input for export

In [20]:
batch_size = 1
x0 = torch.randn(batch_size, 1, 60, 90)   # resized image tensor (60x90)
x1 = torch.randn(batch_size, 1)           # scalar feature (desired velocity)
x2 = torch.randn(batch_size, 4)           # quad orientation feature
dummy_input = [x0, x1, x2]

## 2. Export to ONNX (opset 17)

In [21]:
onnx_fp32 = 'lstmnetvit_fp32.onnx'
torch.onnx.export(
    model,
    dummy_input,
    onnx_fp32,
    opset_version=17,
    input_names=['x'],
    output_names=['out', 'hidden'],
    #dynamic_axes={'x0': {0: 'B'}, 'x1': {0: 'B'}, 'x2': {0: 'B'},
    #              'out': {0: 'B'}, 'hidden': {1: 'B'}}
    )
print('Exported:', onnx_fp32)
onnx.checker.check_model(onnx.load(onnx_fp32))


  k, v = self.keyValueExtractor(x1).reshape(B, -1, 2, self.heads, C // self.heads)                                           .permute(2, 0, 3, 1, 4)


Exported: lstmnetvit_fp32.onnx


## 3. Dynamic INT8 Quantisation with ONNX Runtime

In [24]:
onnx_int8 = 'lstmnetvit_int8.onnx'
quantize_dynamic(
    onnx_fp32,
    onnx_int8,
    weight_type=QuantType.QInt8,   # use signed INT8
)
print('INT8 model saved to', onnx_int8)



INT8 model saved to lstmnetvit_int8.onnx


In [25]:
onnx_uint8 = 'lstmnetvit_uint8.onnx'
quantize_dynamic(
    onnx_fp32,
    onnx_uint8,
    weight_type=QuantType.QUInt8,   # use signed INT8
)
print('INT8 model saved to', onnx_uint8)



INT8 model saved to lstmnetvit_uint8.onnx


## Pytorch Methods

In [32]:
model_fp32 = model
model_fp32.eval()
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment.  If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001

In [31]:
model_fp32 = model
model_fp32.eval()
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
#model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32)
model_fp32_prepared(dummy_input)
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment.  If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001

In [29]:
import torch
import torch.quantization as tq

def remove_weight_norm_recursive(module):
    for child_name, child in module.named_children():
        remove_weight_norm_recursive(child)
        try:
            # This will remove weight norm if applied
            nn.utils.remove_weight_norm(child)
        except (ValueError, AttributeError):
            pass

remove_weight_norm_recursive(model)

import torch.quantization as tq
import torch

torch.backends.quantized.engine = 'fbgemm'
model.eval()
model.qconfig = tq.get_default_qconfig('fbgemm')

model_prepared = tq.prepare(model, inplace=False)
# Run calibration (with your dummy_input)
with torch.no_grad():
    model_prepared(dummy_input)
model_quantized = tq.convert(model_prepared, inplace=False)

with torch.no_grad():
    out_quant, hidden_quant = model_quantized(dummy_input)
print('PyTorch UINT8 quantization complete.')

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment.  If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001

## 4. Validate FP32 vs. INT8 outputs

In [26]:
import numpy as np
def run_ort(sess, *inputs):
    inp_dict = {sess.get_inputs()[i].name: inputs[i].numpy() for i in range(len(inputs))}
    return sess.run(None, inp_dict)

sess_fp32 = ort.InferenceSession(onnx_fp32, providers=['CPUExecutionProvider'])
sess_int8 = ort.InferenceSession(onnx_int8, providers=['CPUExecutionProvider'])

out_fp32, _ = run_ort(sess_fp32, *dummy_input)
out_int8, _ = run_ort(sess_int8, *dummy_input)

diff = np.mean(np.abs(out_fp32 - out_int8))
print(f'Mean absolute difference FP32 vs. INT8: {diff:.6f}')

NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for ConvInteger(10) node with name '/encoder_blocks.0/patchMerge/cn1/Conv_quant'