In [None]:
#| default_exp export.onnx_exporter

In [None]:
#| include: false
from nbdev.showdoc import *

# ONNX Export

Export PyTorch models to ONNX format for deployment. Supports:
- Basic ONNX export with graph optimization
- Dynamic INT8 quantization (no calibration needed)
- Static INT8 quantization (with calibration data)
- Output verification against original model

In [None]:
#| export
from __future__ import annotations

import warnings
from functools import lru_cache
from pathlib import Path
from typing import Iterable

import numpy as np
import torch
import torch.nn as nn

In [None]:
#| export
@lru_cache(maxsize=None)
def _has_package(name: str) -> bool:
    "Check if a package is available (cached)"
    from importlib.util import find_spec
    return find_spec(name) is not None


def _require(*packages: str, install_hint: str | None = None) -> None:
    "Raise ImportError naming every unavailable package, with an install command"
    # Collect all missing names first so one error reports the complete list.
    missing = []
    for pkg in packages:
        if not _has_package(pkg):
            missing.append(pkg)
    if not missing:
        return
    # Prefer the caller's hint; otherwise suggest installing exactly what is missing.
    hint = install_hint if install_hint else f"pip install {' '.join(missing)}"
    raise ImportError(f"Missing packages: {missing}. Install with: {hint}")

## Export Function

In [None]:
#| export
def export_onnx(
    model: nn.Module,                     # PyTorch model to export
    sample: torch.Tensor,                 # Example input for tracing (with batch dim)
    output_path: str | Path,              # Output .onnx file path
    *,
    opset_version: int = 17,              # ONNX opset version (17 recommended for compatibility)
    quantize: bool = False,               # Apply INT8 quantization after export
    quantize_mode: str = "dynamic",       # "dynamic" (no calibration) or "static"
    calibration_data: Iterable | None = None,  # DataLoader for static quantization
    optimize: bool = True,                # Run ONNX graph optimizer (skipped if `onnxoptimizer` is not installed)
    dynamic_batch: bool = True,           # Allow variable batch size at runtime (ignored when `quantize=True`)
    input_names: list[str] | None = None, # Names for input tensors
    output_names: list[str] | None = None,# Names for output tensors
) -> Path:
    """Export a PyTorch model to ONNX format with optional quantization

    Returns the path of the file to deploy: `output_path` itself, or the path
    returned by `_quantize_onnx` when `quantize=True`.

    Side effects: creates the parent directory of `output_path`, switches
    `model` to eval mode, and writes the `.onnx` file.

    Raises `ValueError` for an unknown `quantize_mode` or for static
    quantization without `calibration_data`, and `ImportError` when a required
    package is missing. All of these are raised before anything is written.
    """
    # Fail fast: reject bad quantization arguments before the export runs,
    # rather than after the un-quantized file has already been produced.
    if quantize:
        if quantize_mode not in ("dynamic", "static"):
            raise ValueError(f"Unknown quantize_mode: {quantize_mode}. Use 'dynamic' or 'static'.")
        if quantize_mode == "static" and calibration_data is None:
            raise ValueError("Static quantization requires calibration_data")

    _require("onnx", install_hint="pip install onnx onnxruntime")
    if quantize:
        # Same reasoning: discover a missing runtime before exporting.
        _require("onnxruntime", install_hint="pip install onnxruntime")
    import onnx

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Defaults
    input_names = input_names or ["input"]
    output_names = output_names or ["output"]

    # Quantization requires fixed batch size for shape inference, so the
    # dynamic batch axis is only declared for non-quantized exports.
    # Only the first input and the first output get the dynamic axis.
    dynamic_axes = None
    if dynamic_batch and not quantize:
        dynamic_axes = {
            input_names[0]: {0: "batch_size"},
            output_names[0]: {0: "batch_size"},
        }

    # Export to ONNX using legacy TorchScript exporter for better operator coverage
    # The new dynamo-based exporter (PyTorch 2.x default) has limited op support
    model.eval()
    with warnings.catch_warnings():
        # Tracing is noisy; the filters are scoped to this block only.
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        torch.onnx.export(
            model, sample, str(output_path),
            opset_version=opset_version,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            dynamo=False,  # Use legacy TorchScript exporter for broader op support
        )

    # Optimize the graph in place (best effort: silently skipped when the
    # optional `onnxoptimizer` package is unavailable).
    if optimize and _has_package("onnxoptimizer"):
        import onnxoptimizer
        onnx_model = onnx.load(str(output_path))
        onnx_model = onnxoptimizer.optimize(onnx_model)
        onnx.save(onnx_model, str(output_path))

    # Apply quantization if requested; the quantized model is a separate file.
    if quantize:
        output_path = _quantize_onnx(output_path, quantize_mode, calibration_data, input_names[0])

    return output_path

In [None]:
#| export
def _quantize_onnx(
    onnx_path: Path,
    mode: str,
    calibration_data: Iterable | None,
    input_name: str,
) -> Path:
    "Apply INT8 quantization to an ONNX model"
    _require("onnxruntime", install_hint="pip install onnxruntime")
    
    from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static, shape_inference
    
    # Preprocess for shape inference
    preprocessed = onnx_path.with_stem(f"{onnx_path.stem}_preprocessed")
    shape_inference.quant_pre_process(str(onnx_path), str(preprocessed))
    
    quantized = onnx_path.with_stem(f"{onnx_path.stem}_int8")
    
    if mode == "dynamic":
        quantize_dynamic(str(preprocessed), str(quantized), weight_type=QuantType.QUInt8)
    elif mode == "static":
        if calibration_data is None:
            raise ValueError("Static quantization requires calibration_data")
        
        from onnxruntime.quantization import CalibrationDataReader
        
        class _DataReader(CalibrationDataReader):
            def __init__(self, data_iter, name):
                self.it, self.name = iter(data_iter), name
            def get_next(self):
                try:
                    batch = next(self.it)
                    if isinstance(batch, (tuple, list)): batch = batch[0]
                    return {self.name: batch.numpy()}
                except StopIteration:
                    return None
        
        quantize_static(
            str(preprocessed), str(quantized),
            calibration_data_reader=_DataReader(calibration_data, input_name),
            quant_format=QuantFormat.QDQ,
            activation_type=QuantType.QUInt8,
            weight_type=QuantType.QInt8,
        )
    else:
        raise ValueError(f"Unknown quantize_mode: {mode}. Use 'dynamic' or 'static'.")
    
    preprocessed.unlink(missing_ok=True)
    return quantized

## Inference Wrapper

In [None]:
#| export
class ONNXModel:
    "Wrapper for ONNX Runtime inference with PyTorch-like interface"

    def __init__(self, path: str | Path, device: str = "cpu"):
        # Opens an inference session for the model at `path`. With
        # device == "cuda" the CUDA provider is listed first and the CPU
        # provider second; any other value uses the CPU provider alone.
        _require("onnxruntime", install_hint="pip install onnxruntime")
        import onnxruntime as ort

        self.path, self.device = Path(path), device
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
        self.session = ort.InferenceSession(str(self.path), providers=providers)
        # Only the first declared input and output of the graph are used.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        "Run inference on input tensor and return the first output as a tensor"
        # detach() allows inputs that require grad; cpu() is a no-op for host
        # tensors and brings tensors from any other device to host memory.
        x_np = x.detach().cpu().numpy()
        result = torch.from_numpy(self.session.run([self.output_name], {self.input_name: x_np})[0])
        return result.cuda() if self.device == "cuda" else result

    def warmup(self, sample: torch.Tensor, n: int = 10) -> None:
        "Run warmup iterations to stabilize inference timing"
        for _ in range(n): self(sample)

    def __repr__(self) -> str:
        return f"ONNXModel(path='{self.path}', device='{self.device}')"

## Verification

In [None]:
#| export
def verify_onnx(
    model: nn.Module,        # Original PyTorch model
    onnx_path: str | Path,   # Path to exported ONNX model
    sample: torch.Tensor,    # Test input tensor
    rtol: float = 1e-3,      # Relative tolerance
    atol: float = 1e-5,      # Absolute tolerance
) -> bool:
    """Verify ONNX model outputs match PyTorch model within tolerance

    Runs `sample` through `model` and through the ONNX file (on the CPU
    provider) and compares the two outputs element-wise. Returns False when
    the output shapes differ. Note that `model` is left in eval mode.
    """
    model.eval()
    with torch.no_grad():
        pt_out = model(sample).cpu().numpy()
    # Detach so a sample that requires grad can still be converted to NumPy.
    onnx_out = ONNXModel(onnx_path)(sample.detach().cpu()).numpy()
    # np.allclose broadcasts its arguments: differently shaped outputs could
    # compare as equal, or raise ValueError when they cannot be broadcast.
    # A shape difference is itself a verification failure.
    if pt_out.shape != onnx_out.shape:
        return False
    return bool(np.allclose(pt_out, onnx_out, rtol=rtol, atol=atol))

## Usage Examples

```python
from fasterai.export.all import export_onnx, ONNXModel, verify_onnx

# Basic export
path = export_onnx(model, sample, "model.onnx")

# With quantization (the returned path points to the quantized file, "model_int8.onnx")
path = export_onnx(model, sample, "model.onnx", quantize=True)

# Inference
onnx_model = ONNXModel("model.onnx")
output = onnx_model(input_tensor)

# Verify
assert verify_onnx(model, "model.onnx", sample)
```

In [None]:
show_doc(export_onnx)  # render the API documentation entry for export_onnx

  from .autonotebook import tqdm as notebook_tqdm
  return torch._C._cuda_getDeviceCount() > 0
W0202 11:00:43.171000 421313 site-packages/torch/utils/cpp_extension.py:117] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda-12.8'


Found permutation search CUDA kernels
[ASP][Info] permutation_search_kernels can be imported.


---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/export/onnx_exporter.py#L34){target="_blank" style="float:right; font-size:smaller"}

### export_onnx

```python

def export_onnx(
    model:nn.Module, # PyTorch model to export
    sample:torch.Tensor, # Example input for tracing (with batch dim)
    output_path:str | Path, # Output .onnx file path
    opset_version:int=18, # ONNX opset version
    quantize:bool=False, # Apply INT8 quantization after export
    quantize_mode:str='dynamic', # "dynamic" (no calibration) or "static"
    calibration_data:Iterable | None=None, # DataLoader for static quantization
    optimize:bool=True, # Run ONNX graph optimizer
    dynamic_batch:bool=True, # Allow variable batch size at runtime
    input_names:list[str] | None=None, # Names for input tensors
    output_names:list[str] | None=None, # Names for output tensors
)->Path:


```

*Export a PyTorch model to ONNX format with optional quantization*

In [None]:
show_doc(ONNXModel)  # render the API documentation entry for ONNXModel

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/export/onnx_exporter.py#L145){target="_blank" style="float:right; font-size:smaller"}

### ONNXModel

```python

def ONNXModel(
    path:str | Path, device:str='cpu'
):


```

*Wrapper for ONNX Runtime inference with PyTorch-like interface*

In [None]:
show_doc(verify_onnx)  # render the API documentation entry for verify_onnx

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/export/onnx_exporter.py#L172){target="_blank" style="float:right; font-size:smaller"}

### verify_onnx

```python

def verify_onnx(
    model:nn.Module, # Original PyTorch model
    onnx_path:str | Path, # Path to exported ONNX model
    sample:torch.Tensor, # Test input tensor
    rtol:float=0.001, # Relative tolerance
    atol:float=1e-05, # Absolute tolerance
)->bool:


```

*Verify ONNX model outputs match PyTorch model within tolerance*