# 1. Export IGEV++ to ONNX

Converts an IGEV++ PyTorch checkpoint to an ONNX model that can be loaded
by `ag-cam-tools depth-preview --stereo-backend onnx`.

The exported model expects two float32 inputs of shape `[1, 3, H, W]` with
values in the `[0, 255]` range, where `H` and `W` are the frame dimensions
rounded up to multiples of 32, and produces a float32 disparity map.

**This export is tied to a specific resolution and max_disp.** If you
recalibrate or change binning, re-run this notebook.

---

### Prerequisites

1. Clone the IGEV++ repository:
   ```bash
   git clone https://github.com/gangweiX/IGEV-plusplus.git ~/code/IGEV-plusplus
   ```

2. Download the SceneFlow checkpoint from the
   [Google Drive folder](https://drive.google.com/drive/folders/1eubNsu03MlhUfTtrbtN7bfAsl39s2ywJ):
   ```bash
   mkdir -p ~/code/IGEV-plusplus/pretrained_models/igev_plusplus
   pip install gdown
   gdown --fuzzy "https://drive.google.com/..." \
       -O ~/code/IGEV-plusplus/pretrained_models/igev_plusplus/sceneflow.pth
   ```

3. Have a completed calibration session, produced by `ag-cam-tools calibration-capture`
   followed by `2.Calibration.ipynb`.

### Setup (one-time)
```bash
cd backends
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements-onnx-export.txt
python -m ipykernel install --user --name agrippa-export --display-name "agrippa-export"
```

Then select the **agrippa-export** kernel in Jupyter before running.

## Dependencies

In [1]:
%pip install -r requirements-onnx-export.txt -q


Note: you may need to restart the kernel to use updated packages.


In [2]:
import contextlib
import json
import os
import struct
import sys
import time

import numpy as np
import torch

## Configuration

Set these paths to match your setup.  `calibration_session` should point
to a calibrated session folder containing `calib_result/calibration_meta.json`
and `calib_result/remap_left.bin`.

In [3]:
# Path to cloned IGEV-plusplus repo root (contains core/igev_stereo.py).
igev_src = os.path.expanduser("~/code/IGEV-plusplus")

# Path to pre-trained checkpoint.
checkpoint_path = os.path.join(
    igev_src, "pretrained_models/igev_plusplus/sceneflow.pth")

# Calibration session folder.
calibration_session = "../calibration/calibration_20260225_160029_3cb907d2"

# GRU refinement iterations (12 is a good balance; upstream default is 16).
iters = 12

# Output ONNX file path.
output_path = "../models/igev_plusplus.onnx"

# Optional overrides (set to None to auto-detect from calibration).
override_max_disp = None   # e.g. 192
override_width    = None   # e.g. 720
override_height   = None   # e.g. 540
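
Before running the remaining cells, it can help to confirm that the files named above actually exist. The following is a minimal sketch: `missing_inputs` is a hypothetical helper, not part of `ag-cam-tools`, and it only checks the three paths this notebook reads.

```python
import os

def missing_inputs(checkpoint_path, calibration_session):
    """Return the required input files that do not exist (hypothetical helper)."""
    calib_dir = os.path.join(calibration_session, "calib_result")
    required = [
        checkpoint_path,                                   # IGEV++ weights
        os.path.join(calib_dir, "calibration_meta.json"),  # disparity range
        os.path.join(calib_dir, "remap_left.bin"),         # frame dimensions
    ]
    # Keep the original order so the report matches the list above.
    return [path for path in required if not os.path.isfile(path)]
```

Calling `missing_inputs(checkpoint_path, calibration_session)` right after the configuration cell should return an empty list when everything is in place.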

## Load calibration metadata

Reads `calibration_meta.json` to derive `max_disp` (cost volume size)
and the remap table header to get the image dimensions.

In [None]:
# --- Read max_disp from calibration_meta.json ---
meta_path = os.path.join(calibration_session, "calib_result",
                         "calibration_meta.json")
with open(meta_path, "r") as f:
    meta = json.load(f)

dr = meta.get("disparity_range", {})
num_disparities = dr.get("num_disparities", 128)

# IGEV++ builds a 3D cost volume with D = max_disp // 4 disparity levels.
# patch0 (stride 2) and patch1 (stride 4) downsample D for the multi-scale
# geometry volumes, and each is fed through a 3-level 3D hourglass that
# halves D three times then doubles it back.  For the deconv skip
# connections to align, every hourglass input D must be divisible by 8.
#
# With max_disp=256 → D=64:
#   cost_agg0: D=48 (sliced)  → 48/8=6 ✓
#   cost_agg1: patch0(64)→32  → 32/8=4 ✓
#   cost_agg2: patch1(64)→16  → 16/8=2 ✓
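# Note: the s/m/l disparity ranges set when building the model args are
# worked out for max_disp=256; a larger value here may need them adjusted.
max_disp = override_max_disp or max(num_disparities, 256)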

print(f"max_disp = {max_disp}  (calibration num_disparities = {num_disparities})")

# --- Read image dimensions from remap table header ---
# Binary format: "RMAP" (4 bytes) + width(u32) + height(u32) + flags(u32)
width = override_width
height = override_height

if width is None or height is None:
    remap_path = os.path.join(calibration_session, "calib_result",
                              "remap_left.bin")
    with open(remap_path, "rb") as f:
        magic = f.read(4)
        # Raise explicitly: assert statements are skipped under `python -O`.
        if magic != b"RMAP":
            raise ValueError(f"Bad remap magic: {magic!r}")
        w, h, _flags = struct.unpack("<III", f.read(12))
        if width is None:
            width = w
        if height is None:
            height = h

def pad_to_32(dim):
    """Round up to nearest multiple of 32."""
    return ((dim + 31) // 32) * 32

pad_w = pad_to_32(width)
pad_h = pad_to_32(height)

print(f"Frame dimensions: {width}x{height}")
print(f"Padded (32-divisible): {pad_w}x{pad_h}")
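
The header layout described in the cell above (a 4-byte `RMAP` magic followed by three little-endian `u32` fields) can be exercised without a real calibration file. The sketch below builds a synthetic 16-byte header in memory and parses it back. `read_remap_header` is a hypothetical helper name, and the 720x540 frame size is only an illustrative value borrowed from the override examples.

```python
import io
import struct

def read_remap_header(stream):
    """Parse the 16-byte remap header and return (width, height, flags)."""
    magic = stream.read(4)
    if magic != b"RMAP":
        raise ValueError(f"Bad remap magic: {magic!r}")
    raw = stream.read(12)
    if len(raw) != 12:
        raise ValueError("Truncated remap header")
    # Three unsigned 32-bit little-endian integers.
    return struct.unpack("<III", raw)

def pad_to_32(dim):
    """Round up to the nearest multiple of 32."""
    return ((dim + 31) // 32) * 32

# Synthetic header for an illustrative 720x540 frame with flags = 0.
header = b"RMAP" + struct.pack("<III", 720, 540, 0)
width, height, flags = read_remap_header(io.BytesIO(header))
print(width, height, pad_to_32(width), pad_to_32(height))  # 720 540 736 544
```

The padded size is what the exported model is traced at, so a 720x540 frame would yield a model with fixed 736x544 inputs.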

## Patch autocast for CPU tracing

IGEV++ uses `torch.cuda.amp.autocast` in its forward pass, which breaks
ONNX tracing on CPU-only machines.  We replace it with a no-op context
manager.

In [5]:
class _NoOpAutocast:
    """Drop-in replacement for torch.cuda.amp.autocast."""
    def __init__(self, *args, **kwargs):
        pass
    def __enter__(self):
        return self
    def __exit__(self, *args):
        pass
    def __call__(self, func):
        return func

if hasattr(torch.cuda, "amp"):
    torch.cuda.amp.autocast = _NoOpAutocast

if hasattr(torch, "autocast"):
    _orig_autocast = torch.autocast
    def _patched_autocast(device_type="cuda", *args, **kwargs):
        if device_type == "cuda":
            return contextlib.nullcontext()
        return _orig_autocast(device_type, *args, **kwargs)
    torch.autocast = _patched_autocast

print("Patched autocast for CPU tracing.")

Patched autocast for CPU tracing.


## Load IGEV++ model

Imports the IGEV++ model from the cloned source, builds the configuration
args it expects, and loads the pre-trained checkpoint.

In [None]:
# Add IGEV++ source to path.
sys.path.insert(0, igev_src)

# Try known import paths.
try:
    from core.igev_stereo import IGEVStereo
    print("Imported IGEVStereo from core.igev_stereo")
except ImportError:
    from igev_stereo import IGEVStereo
    print("Imported IGEVStereo from igev_stereo")

# Build args namespace that IGEV++ expects.
#
# The disparity range parameters control how many regression values each
# multi-scale head produces:  disp_range / disp_interval == D for that head.
# With max_disp=256 → cost volume D = 64:
#   cost_agg0: slice first 48 of 64             → D=48 → s_disp_range/1 = 48 ✓
#   cost_agg1: patch0 (stride 2) on first 64    → D=32 → m_disp_range/2 = 32 ✓
#   cost_agg2: patch1 (stride 4) on all 64      → D=16 → l_disp_range/4 = 16 ✓
class ModelArgs:
    pass

args = ModelArgs()
args.max_disp = max_disp
args.valid_iters = iters
args.mixed_precision = False
args.precision_dtype = "float32"
args.corr_implementation = "reg"
args.shared_backbone = False
args.corr_levels = 2
args.corr_radius = 4
args.n_downsample = 2
args.n_gru_layers = 3
args.slow_fast_gru = False
args.hidden_dims = [128] * 3
args.s_disp_range = 48
args.m_disp_range = 64
args.l_disp_range = 64
args.s_disp_interval = 1
args.m_disp_interval = 2
args.l_disp_interval = 4

model = IGEVStereo(args)
print(f"Model instantiated (max_disp={max_disp}, iters={iters})")

# Load checkpoint.
# weights_only=False is required: the upstream checkpoint was saved with
# older PyTorch and contains objects that the safe unpickler rejects.
# Unpickling can execute arbitrary code, so only load checkpoints you trust.
print(f"Loading checkpoint: {checkpoint_path}")
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# Handle DataParallel wrapper prefix.
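if any(k.startswith("module.") for k in state_dict.keys()):
    # Strip only the leading prefix; str.replace would also rewrite
    # "module." if it occurred elsewhere in a key.
    state_dict = {(k[len("module."):] if k.startswith("module.") else k): v
                  for k, v in state_dict.items()}
    print("Stripped 'module.' prefix from state_dict.")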

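# strict=False tolerates key mismatches, so report them rather than
# silently exporting a partially initialised model.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.missing_keys)} missing and "
          f"{len(load_result.unexpected_keys)} unexpected keys in checkpoint.")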
model.eval()
print("Checkpoint loaded.")
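
The relationship stated in the comments above (`disp_range / disp_interval` gives the depth `D` of each head, and each `D` must be divisible by 8 for the hourglass skip connections to align) can be checked with a few lines of arithmetic. This is a sketch of that check only: `head_depths` and `hourglass_compatible` are hypothetical helpers and do not inspect the IGEV++ model itself.

```python
def head_depths(s_range=48, m_range=64, l_range=64,
                s_interval=1, m_interval=2, l_interval=4):
    """Return the cost-volume depth D for the small, medium and large heads."""
    return (s_range // s_interval,
            m_range // m_interval,
            l_range // l_interval)

def hourglass_compatible(depths):
    """True if every depth can be halved three times, i.e. is divisible by 8."""
    return all(d > 0 and d % 8 == 0 for d in depths)

# Defaults mirror the args set in the cell above.
depths = head_depths()
print(depths, hourglass_compatible(depths))  # (48, 32, 16) True
```

Running the check before instantiating the model gives an early, readable failure if the ranges are edited to values the hourglass cannot handle.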

## Export to ONNX

Traces the model at the padded resolution with `torch.onnx.export` at
opset 16.  The dummy inputs use the `[0, 255]` float32 range (not `[0, 1]`),
matching what the C backend will send.

In [None]:
# Ensure output directory exists.
out_dir = os.path.dirname(output_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

# IGEV++ expects [0, 255] float32 input.  torch.rand is uniform on [0, 1),
# so scaling by 255 keeps the dummy values inside that range.
left_dummy = torch.rand(1, 3, pad_h, pad_w) * 255.0
right_dummy = torch.rand(1, 3, pad_h, pad_w) * 255.0

# Set test_mode to get single output (final refinement only).
model.test_mode = True

print(f"Tracing with opset 16 at {pad_w}x{pad_h} (iters={iters})...")
t0 = time.time()

with torch.no_grad():
    torch.onnx.export(
        model,
        (left_dummy, right_dummy),
        output_path,
        opset_version=16,
        input_names=["left", "right"],
        output_names=["disparity"],
        dynamic_axes=None,  # fixed resolution for best performance
        do_constant_folding=True,
        dynamo=False,  # force legacy JIT-trace exporter (dynamo can't handle IGEV++)
    )

dt = time.time() - t0
size_mb = os.path.getsize(output_path) / (1024 * 1024)
print(f"Raw ONNX exported in {dt:.1f}s ({size_mb:.1f} MB)")

## Simplify with onnxsim

Runs graph optimizations (constant folding, shape inference, dead-node
elimination) to reduce model size and improve inference speed.

In [None]:
import onnx
from onnxsim import simplify

print("Simplifying with onnxsim...")
t0 = time.time()

onnx_model = onnx.load(output_path)
model_sim, check = simplify(onnx_model)

if check:
    onnx.save(model_sim, output_path)
    dt = time.time() - t0
    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"Simplified in {dt:.1f}s ({size_mb:.1f} MB)")
else:
    print("onnxsim check failed, keeping unsimplified model.")

## Validate with ONNX Runtime

Loads the exported model in ONNX Runtime and runs a single inference pass
on random input to verify that the model loads and that the output has the
expected padded shape.

In [None]:
import onnxruntime as ort

session = ort.InferenceSession(output_path,
                               providers=["CPUExecutionProvider"])

inputs = session.get_inputs()
outputs = session.get_outputs()
print(f"Inputs:  {[(i.name, i.shape, i.type) for i in inputs]}")
print(f"Outputs: {[(o.name, o.shape, o.type) for o in outputs]}")

# Warm-up / validation inference.
left = np.random.uniform(0, 255, (1, 3, pad_h, pad_w)).astype(np.float32)
right = np.random.uniform(0, 255, (1, 3, pad_h, pad_w)).astype(np.float32)

t0 = time.time()
results = session.run(None, {inputs[0].name: left, inputs[1].name: right})
dt = time.time() - t0

disp = results[-1].squeeze()
print(f"\nInference time: {dt:.2f}s")
print(f"Output shape: {disp.shape}")
print(f"Disparity range: [{disp.min():.2f}, {disp.max():.2f}]")

if disp.shape != (pad_h, pad_w):
    print(f"\nWARNING: output shape {disp.shape} != expected ({pad_h}, {pad_w})")
else:
    print("\nValidation OK.")

## Done

The ONNX model has been exported, simplified, and validated.

### Next steps

1. Copy the model to your target machine:
   ```bash
   scp models/igev_plusplus.onnx jetson:~/agrippa-stereocam/models/
   ```

2. Run depth preview:
   ```bash
   ag-cam-tools depth-preview \
       --rectify calibration/calibration_YYYYMMDD_HHMMSS \
       --stereo-backend onnx \
       --model-path models/igev_plusplus.onnx
   ```

The C backend automatically selects the best execution provider
(CUDA > CoreML > CPU) and handles padding/preprocessing.
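
When feeding the model from Python instead of the C backend, the same preparation has to be done by hand. The sketch below shows one plausible version: it converts an `H x W x 3` `uint8` image to a `[1, 3, H, W]` float32 array in `[0, 255]`, pads the bottom and right edges up to multiples of 32, and crops a disparity map back to the original size. The helper names are hypothetical, the channel order is left unchanged, and the padding mode (edge replication) is an assumption, since the C backend's actual padding strategy is not described here.

```python
import numpy as np

def pad_to_32(dim):
    """Round up to the nearest multiple of 32."""
    return ((dim + 31) // 32) * 32

def prepare_input(image_hwc_u8):
    """Convert an HxWx3 uint8 image to a padded [1, 3, H', W'] float32 array."""
    h, w, _ = image_hwc_u8.shape
    # Keep the [0, 255] value range; only the dtype and layout change.
    chw = image_hwc_u8.astype(np.float32).transpose(2, 0, 1)
    extra_h, extra_w = pad_to_32(h) - h, pad_to_32(w) - w
    # Assumption: replicate edge pixels on the bottom and right.
    padded = np.pad(chw, ((0, 0), (0, extra_h), (0, extra_w)), mode="edge")
    return np.ascontiguousarray(padded[np.newaxis]), (h, w)

def crop_disparity(disp, original_hw):
    """Remove the bottom/right padding from a disparity map."""
    h, w = original_hw
    return disp[..., :h, :w]
```

The array returned by `prepare_input` can be passed to `session.run` in the same way as the random inputs in the validation cell, and `crop_disparity` then restores the original frame size.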