# Heterogeneous Mini‑Models – Export & Quantize to ONNX

This standalone notebook builds **four neural‑network toy models**, each illustrating a different
mix of operators (Conv, custom Mix‑Transformer attention, LSTM, Dense). It then
exports every network to **ONNX opset 17** and produces an **INT8 weight‑only** version
using ONNX Runtime’s dynamic quantizer.

To make the code easily hackable, *all logic lives in this notebook* – no extra
files needed. GPU is **not** required; CPU is fine.


## 0. Install required packages

In [None]:
!pip install -q torch==2.2 onnx onnxruntime onnxruntime-tools

## 1. Mix‑Transformer building blocks  
The custom encoder layer below follows the *MiT* (Mix Transformer) encoder introduced in the SegFormer paper ("SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers", Xie et al., 2021).  
Four helper classes work together:

1. **`OverlapPatchMerging`** – slides a convolution with stride > 1 to
   turn the input image into overlapping patches **and** reduce the spatial
   resolution. The conv’s output `(B,C,H',W')` is flattened into
   `(B, N, C)` where `N = H'·W'`, then LayerNorm is applied along the channel
   dimension.
2. **`EfficientSelfAttention`** – two‑stage self‑attention:
   * *Reduction stage* – a strided conv downsamples keys & values so the
 	quadratic `N×N` cost is lowered.
   * *Attention stage* – queries come from the **original** tokens, while
 	keys/values come from the reduced set. A classic scaled dot‑product
 	with softmax is followed by a final projection.
3. **`MixFFN`** – feed‑forward network that first expands channels with an
   MLP, applies a **depth‑wise 3×3 conv** (injecting local inductive bias),
   then projects back to the original width.
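4. **`MixTransformerEncoderLayer`** – wraps everything: patch merging once,
   followed by *N* blocks, each computing
   **Attention → add & LayerNorm → MixFFN → add & LayerNorm** (one LayerNorm
   per block, reused for both residual sums). The tokens are finally reshaped
   back to an image‑like `(B, C, H', W')` tensor.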

The implementation matches the code you provided; extra inline comments were
added for clarity.


In [1]:

import torch
import torch.nn as nn

# ---------- 1. Overlapping patch embedding ----------
class OverlapPatchMerging(nn.Module):
	"""Embed an image as a normalised token sequence using one strided convolution."""

	def __init__(self, in_channels, out_channels, patch_size, stride, padding):
		super().__init__()
		# A single convolution cuts the patches (they overlap whenever
		# patch_size > stride) and shrinks the spatial grid at the same time.
		self.cn1 = nn.Conv2d(
			in_channels,
			out_channels,
			kernel_size=patch_size,
			stride=stride,
			padding=padding,
		)
		# Normalises each token over its channel dimension.
		self.layerNorm = nn.LayerNorm(out_channels)

	def forward(self, patches):
		"""
		Args:
			patches: image batch of shape (B, C, H, W)
		Returns:
			tokens : (B, N, C_out) with N = H'*W', layer-normalised
			H', W' : height and width of the convolution output
		"""
		feature_map = self.cn1(patches)                    # (B, C_out, H', W')
		_, _, grid_h, grid_w = feature_map.shape
		# Collapse the spatial grid into one axis, then put channels last.
		tokens = feature_map.flatten(2).transpose(1, 2)    # (B, N, C_out)
		return self.layerNorm(tokens), grid_h, grid_w

# ---------- 2. Token‑reduced self‑attention ----------
class EfficientSelfAttention(nn.Module):
	"""Multi-head self-attention whose keys/values come from a spatially reduced token set.

	Queries are computed from every input token, while keys and values are
	computed from a token grid that a strided convolution has shrunk by
	`reduction_ratio` along each spatial axis.
	"""

	def __init__(self, channels, reduction_ratio, num_heads):
		"""
		Args:
			channels        : token width C; must be a multiple of num_heads
			reduction_ratio : kernel size and stride of the K/V reduction conv
			num_heads       : number of attention heads
		"""
		super().__init__()
		assert channels % num_heads == 0, "channels must be divisible by num_heads"
		self.heads = num_heads

		# Reduction: kernel == stride == reduction_ratio (non-overlapping windows)
		self.cn1 = nn.Conv2d(channels, channels,
							kernel_size=reduction_ratio,
							stride=reduction_ratio)
		self.ln1 = nn.LayerNorm(channels)

		# Attention projections (keys and values share one linear layer)
		self.keyValueExtractor = nn.Linear(channels, channels * 2)
		self.query = nn.Linear(channels, channels)

		self.smax = nn.Softmax(dim=-1)
		self.finalLayer = nn.Linear(channels, channels)

	def forward(self, x, H, W):
		"""
		Args:
			x : (B, N, C) with N = H*W
			H, W : spatial size before flattening
		Returns:
			(B, N, C)
		"""
		B, N, C = x.shape
		head_dim = C // self.heads

		# 1. reduce tokens for K,V ------------------------------------------
		x1 = x.permute(0, 2, 1).reshape(B, C, H, W)             # (B,C,H,W)
		x1 = self.cn1(x1)                                       # (B,C,H//rr,W//rr)
		x1 = x1.reshape(B, C, -1).permute(0, 2, 1).contiguous() # (B,N',C)
		x1 = self.ln1(x1)

		# 2. project to Q,K,V -----------------------------------------------
		kv = self.keyValueExtractor(x1)                         # (B,N',2C)
		kv = kv.reshape(B, -1, 2, self.heads, head_dim)
		kv = kv.permute(2, 0, 3, 1, 4)                          # (2,B,h,N',C/h)
		k, v = kv[0], kv[1]

		q = self.query(x).reshape(B, N, self.heads, head_dim)
		q = q.permute(0, 2, 1, 3)                               # (B,h,N,C/h)

		# 3. scaled dot-product attention -----------------------------------
		# Logits are divided by sqrt(head_dim) before the softmax over N'.
		scale = head_dim ** 0.5
		attn = self.smax(q @ k.transpose(-2, -1) / scale)       # (B,h,N,N')
		# Merge the heads back into the channel axis.
		ctx = (attn @ v).transpose(1, 2).reshape(B, N, C)       # (B,N,C)

		return self.finalLayer(ctx)                             # (B,N,C)

# ---------- 3. Feed‑forward network with depth‑wise conv ----------
class MixFFN(nn.Module):
	"""Feed-forward block: Linear expand -> depth-wise 3x3 conv -> GELU -> Linear project."""

	def __init__(self, channels, expansion_factor):
		"""
		Args:
			channels         : token width C at the input and output
			expansion_factor : hidden width is channels * expansion_factor
		"""
		super().__init__()
		exp = channels * expansion_factor
		self.mlp1 = nn.Linear(channels, exp)
		# groups must equal the number of conv channels (exp) for a true
		# depth-wise convolution, i.e. one independent 3x3 filter per channel.
		# groups=channels would let every group mix `expansion_factor` channels.
		self.depthwise = nn.Conv2d(exp, exp, 3,
								padding=1, groups=exp)
		self.gelu = nn.GELU()
		self.mlp2 = nn.Linear(exp, channels)

	def forward(self, x, H, W):
		"""
		Args:
			x : (B, N, C) with N = H*W
			H, W : spatial size of the token grid
		Returns:
			(B, N, C)
		"""
		x = self.mlp1(x)                                   # (B,N,exp)
		B, _, hidden = x.shape
		# Restore the 2-D grid so the conv can look at spatial neighbours.
		x = x.transpose(1, 2).reshape(B, hidden, H, W)     # (B,exp,H,W)
		x = self.depthwise(x).flatten(2).transpose(1, 2)   # (B,N,exp)
		return self.mlp2(self.gelu(x))                     # (B,N,C)

# ---------- 4. Full encoder layer ------------------------------------------
class MixTransformerEncoderLayer(nn.Module):
	"""One encoder stage: patch merging followed by `n_layers` attention/FFN blocks."""

	def __init__(self, in_channels, out_channels, patch_size, stride, padding,
				n_layers, reduction_ratio, num_heads, expansion_factor):
		super().__init__()
		self.patchMerge = OverlapPatchMerging(
			in_channels, out_channels, patch_size, stride, padding)

		# Three parallel lists; entry i of each belongs to block i.
		attention_blocks = [
			EfficientSelfAttention(out_channels, reduction_ratio, num_heads)
			for _ in range(n_layers)
		]
		self._attn = nn.ModuleList(attention_blocks)

		ffn_blocks = [
			MixFFN(out_channels, expansion_factor)
			for _ in range(n_layers)
		]
		self._ffn = nn.ModuleList(ffn_blocks)

		norm_blocks = [nn.LayerNorm(out_channels) for _ in range(n_layers)]
		self._lNorm = nn.ModuleList(norm_blocks)

	def forward(self, x):
		"""
		Args:
			x : (B, C_in, H, W)
		Returns:
			(B, C_out, H', W')
		"""
		batch, _, _, _ = x.shape
		tokens, grid_h, grid_w = self.patchMerge(x)        # (B,N,C_out)

		for idx, attention in enumerate(self._attn):
			feed_forward = self._ffn[idx]
			norm = self._lNorm[idx]
			# The block's one LayerNorm is applied after each residual sum.
			tokens = norm(tokens + attention(tokens, grid_h, grid_w))
			tokens = norm(tokens + feed_forward(tokens, grid_h, grid_w))

		# Unflatten the token axis and move channels in front of the grid.
		grid = tokens.reshape(batch, grid_h, grid_w, -1)   # (B,H',W',C_out)
		return grid.permute(0, 3, 1, 2).contiguous()       # (B,C_out,H',W')


## 2. Four toy networks
Below we define four miniature models, listed in order of decreasing
complexity. Only `HeteroVIT` uses the Mix‑Transformer encoder layer.

| Model | Branch‑1 | Branch‑2 | Recurrent | Head |
|-------|----------|----------|-----------|------|
| `HeteroVIT`   | Conv | **Mix‑Transformer** | LSTM | Dense |
| `HeteroDense` | Conv | Dense | LSTM | Dense |
| `CNN2_LSTM`   | Conv×2 | — | LSTM | Dense |
| `CNN1_Dense`  | Conv | — | — | Dense |

All share the same input shape **(B, 1, 60, 90)** to keep ONNX export simple.


In [5]:

import torch
import torch.nn as nn

# ---------- variant A -------------------------------------------------------
class HeteroVIT(nn.Module):
	"""Conv branch + Mix-Transformer branch, fused by a one-step LSTM and a dense head."""

	def __init__(self, lstm_h=64, n_classes=10):
		super().__init__()
		# Branch 1: 3x3 conv that keeps the 60x90 resolution (padding=1).
		self.conv = nn.Conv2d(1, 8, 3, padding=1)
		# Branch 2: encoder stage with 32 output channels, 4x4 patches, stride 4.
		self.vit = MixTransformerEncoderLayer(
			1, 32, 4, 4, 0,
			n_layers=1,
			reduction_ratio=4,
			num_heads=4,
			expansion_factor=4,
		)
		self.pool = nn.AdaptiveAvgPool2d((1, 1))
		# LSTM input = flattened conv map (8*60*90) + pooled encoder channels (32).
		self.lstm = nn.LSTM(8*60*90 + 32, lstm_h, batch_first=True)
		self.fc = nn.Linear(lstm_h, n_classes)

	def forward(self, x, h0=None):
		"""
		Args:
			x  : (B, 1, 60, 90)
			h0 : optional initial state, handed unchanged to nn.LSTM
			     (nn.LSTM expects an (h_0, c_0) tuple or None)
		Returns:
			logits (B, n_classes), plus the LSTM's final hidden and cell states
		"""
		conv_features = self.conv(x).flatten(1)
		vit_features = self.pool(self.vit(x)).flatten(1)
		# Concatenate both branches and add a length-1 time axis for the LSTM.
		fused = torch.cat([conv_features, vit_features], 1)
		sequence = fused.unsqueeze(1)
		y, (h, c) = self.lstm(sequence, h0)
		logits = self.fc(y.squeeze(1))
		return logits, h, c

# ---------- variant B -------------------------------------------------------
class HeteroDense(nn.Module):
	"""Conv branch + dense projection branch, fused by a one-step LSTM and a dense head."""

	def __init__(self, lstm_h=64, n_classes=10):
		super().__init__()
		# Branch 1: 3x3 conv that keeps the 60x90 resolution (padding=1).
		self.conv = nn.Conv2d(1, 8, 3, padding=1)
		# Branch 2: linear projection of the flattened 60*90 input to 32 features.
		self.proj = nn.Linear(60*90, 32)
		# LSTM input = flattened conv map (8*60*90) + projected features (32).
		self.lstm = nn.LSTM(8*60*90 + 32, lstm_h, batch_first=True)
		self.fc = nn.Linear(lstm_h, n_classes)

	def forward(self, x, h0=None):
		"""
		Args:
			x  : (B, 1, 60, 90)
			h0 : optional initial state, handed unchanged to nn.LSTM
			     (nn.LSTM expects an (h_0, c_0) tuple or None)
		Returns:
			logits (B, n_classes), plus the LSTM's final hidden and cell states
		"""
		conv_features = self.conv(x).flatten(1)
		dense_features = self.proj(x.flatten(1))
		# Concatenate both branches and add a length-1 time axis for the LSTM.
		fused = torch.cat([conv_features, dense_features], 1)
		sequence = fused.unsqueeze(1)
		y, (h, c) = self.lstm(sequence, h0)
		logits = self.fc(y.squeeze(1))
		return logits, h, c

# ---------- variant C -------------------------------------------------------
class CNN2_LSTM(nn.Module):
	"""Two stacked 3x3 convolutions feeding a one-step LSTM and a dense head."""

	def __init__(self, lstm_h=64, n_classes=10):
		super().__init__()
		# Both convs use padding=1, so the 60x90 resolution is preserved.
		self.conv1 = nn.Conv2d(1, 8, 3, padding=1)
		self.conv2 = nn.Conv2d(8, 16, 3, padding=1)
		# LSTM input = flattened second conv map (16*60*90).
		self.lstm = nn.LSTM(16*60*90, lstm_h, batch_first=True)
		self.fc = nn.Linear(lstm_h, n_classes)

	def forward(self, x, h0=None):
		"""
		Args:
			x  : (B, 1, 60, 90)
			h0 : optional initial state, handed unchanged to nn.LSTM
			     (nn.LSTM expects an (h_0, c_0) tuple or None)
		Returns:
			logits (B, n_classes), plus the LSTM's final hidden and cell states
		"""
		hidden_map = self.conv1(x)
		features = self.conv2(hidden_map).flatten(1)
		# Add a length-1 time axis for the LSTM.
		sequence = features.unsqueeze(1)
		y, (h, c) = self.lstm(sequence, h0)
		logits = self.fc(y.squeeze(1))
		return logits, h, c

# ---------- variant D -------------------------------------------------------
class CNN1_Dense(nn.Module):
	"""Single 3x3 convolution followed by a dense classification head (no recurrence)."""

	def __init__(self, n_classes=10):
		super().__init__()
		# padding=1 keeps the 60x90 resolution, hence the 8*60*90 dense input.
		self.conv = nn.Conv2d(1, 8, 3, padding=1)
		self.fc = nn.Linear(8*60*90, n_classes)

	def forward(self, x, *args):
		"""
		Args:
			x     : (B, 1, 60, 90)
			*args : ignored; accepted so the call signature matches the
			        sibling models that take an initial LSTM state
		Returns:
			logits (B, n_classes) and two 1-element zero tensors standing in
			for the hidden/cell states returned by the recurrent models
		"""
		logits = self.fc(self.conv(x).flatten(1))
		# torch.Tensor(1) would hand back *uninitialised* memory (arbitrary
		# values, and a tracer warning on ONNX export); use explicit zeros.
		h = x.new_zeros(1)
		c = x.new_zeros(1)
		return logits, h, c


## 3. Export to ONNX & dynamic INT8 quantization
 * ONNX **opset 17** keeps the LSTM and custom attention intact.
 * **Dynamic** quantization stores the weights as INT8 (`QInt8`) ahead of time
   and computes the activation quantization parameters on the fly at inference
   time – no calibration data needed.


In [6]:

from pathlib import Path
from onnxruntime.quantization import quantize_dynamic, QuantType
import torch

# name -> freshly initialised (untrained) network; the name becomes the file stem
MODELS = {
	"hetero_vit"  : HeteroVIT(),
	"hetero_dense": HeteroDense(),
	"cnn2_lstm"   : CNN2_LSTM(),
	"cnn1_dense"  : CNN1_Dense(),
}

out_dir = Path("models/dummy")
# parents=True: the intermediate "models/" directory may not exist yet, in
# which case a plain mkdir(exist_ok=True) raises FileNotFoundError.
out_dir.mkdir(parents=True, exist_ok=True)
example = torch.randn(1,1,60,90)
opset   = 17

for name, net in MODELS.items():
	fp32 = out_dir / f"{name}.onnx"
	int8 = out_dir / f"{name}_int8.onnx"

	# Export an inference-mode graph; eval() is also applied to the smoke test
	# below so both run the module in the same mode.
	net.eval()

	# Smoke test: make sure the model runs before trying to export it
	try:
		with torch.no_grad():
			output_example = net(example)
		print("output shape", output_example[0].shape)
	except Exception as e:
		# Best effort: report the failure and move on to the next model.
		print("model forward failed", e)
		continue

	# Export the model to ONNX format
	try:
		torch.onnx.export(net, example, fp32,
						opset_version=opset,
						input_names=["input"],
						output_names=["logits","h","c"],
						do_constant_folding=True)
		print("saved", fp32)
	except Exception as e:
		print("onnx export failed", e)
		continue

	# Quantize the model (reads the FP32 file, writes the INT8 file next to it)
	try:
		quantize_dynamic(fp32, int8,
						weight_type=QuantType.QInt8)
		print("saved", int8)
	except Exception as e:
		print("onnx quantization failed", e)
		continue


output shape torch.Size([1, 10])




saved models/dummy/hetero_vit.onnx
saved models/dummy/hetero_vit_int8.onnx
output shape torch.Size([1, 10])




saved models/dummy/hetero_dense.onnx
saved models/dummy/hetero_dense_int8.onnx
output shape torch.Size([1, 10])




saved models/dummy/cnn2_lstm.onnx


  return self.fc(self.conv(x).flatten(1)), torch.Tensor(1), torch.Tensor(1)


saved models/dummy/cnn2_lstm_int8.onnx
output shape torch.Size([1, 10])
saved models/dummy/cnn1_dense.onnx
saved models/dummy/cnn1_dense_int8.onnx
