# 2. Export FoundationStereo to ONNX

Converts a FoundationStereo PyTorch model to an ONNX model that can be
loaded by `ag-cam-tools depth-preview --stereo-backend onnx`.

The exported model expects two float32 inputs of shape `[1, 3, H, W]`, where
`H` and `W` are the calibrated frame dimensions rounded up to multiples of 32,
and produces a float32 disparity map.

**This export is tied to a specific resolution** because the graph is traced
with static shapes (`dynamic_axes=None`). If you recalibrate or change
binning, re-run this notebook.

---

### Prerequisites

1. Clone the FoundationStereo repository:
   ```bash
   git clone https://github.com/NVlabs/FoundationStereo.git ~/code/FoundationStereo
   ```

2. Download the pre-trained checkpoint (see the FoundationStereo README for links).

3. Have a completed calibration session (from `ag-cam-tools calibration-capture`
   + `2.Calibration.ipynb`).

### Setup (one-time)
```bash
cd backends
python3 -m venv .venv
source .venv/bin/activate
pip install torch torchvision onnx onnxsim onnxruntime numpy ipykernel
python -m ipykernel install --user --name agrippa-export --display-name "agrippa-export"
```

Then select the **agrippa-export** kernel in Jupyter before running.

> **Note:** FoundationStereo may have additional dependencies (e.g. `timm`,
> `einops`).  Install them as needed when the model import cell runs.

## Dependencies

In [None]:
%pip install torch torchvision onnx onnxsim onnxruntime numpy -q

In [None]:
import json
import os
import struct
import sys
import time

import numpy as np
import torch

## Configuration

Set these paths to match your setup.

In [None]:
# Path to cloned FoundationStereo repo root.
foundation_src = os.path.expanduser("~/code/FoundationStereo")

# Path to pre-trained checkpoint.
checkpoint_path = os.path.join(foundation_src, "pretrained/model.pth")

# Calibration session folder.
calibration_session = "../calibration/calibration_YYYYMMDD_HHMMSS"

# Output ONNX file path.
output_path = "../models/foundation_stereo.onnx"

# Optional overrides (set to None to auto-detect from calibration).
override_width  = None   # e.g. 720
override_height = None   # e.g. 540

## Load calibration metadata

Reads the remap table header to get the image dimensions.

In [None]:
width = override_width
height = override_height

if width is None or height is None:
    remap_path = os.path.join(calibration_session, "calib_result",
                              "remap_left.bin")
    with open(remap_path, "rb") as f:
        header = f.read(8)
        w, h = struct.unpack("<II", header)
        if width is None:
            width = w
        if height is None:
            height = h

def pad_to_32(dim):
    """Round up to the next multiple of 32 (unchanged if already a multiple)."""
    return ((dim + 31) // 32) * 32

pad_w = pad_to_32(width)
pad_h = pad_to_32(height)

print(f"Frame dimensions: {width}x{height}")
print(f"Padded (32-divisible): {pad_w}x{pad_h}")
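
The padded size matters again at inference time, because real frames have to be brought to the padded shape before they reach the model and the disparity has to be cropped back afterwards. The sketch below shows one way to do that in NumPy. The helper names `pad_to_model_input` and `crop_disparity` are hypothetical, and padding on the bottom and right with replicated edge pixels is an assumption, not necessarily what `stereo_onnx.c` does. Pixel values are left unscaled. As a worked case, the 720x540 frame used in the configuration comments pads to 736x544.

```python
import numpy as np


def pad_to_model_input(image_hwc, pad_h, pad_w):
    """Convert an HxWx3 image to a float32 [1, 3, pad_h, pad_w] array.

    Assumption: padding is added on the bottom and right by repeating
    edge pixels. Pixel values are not rescaled here.
    """
    h, w, c = image_hwc.shape
    if c != 3 or h > pad_h or w > pad_w:
        raise ValueError(f"cannot pad {image_hwc.shape} to {pad_h}x{pad_w}")
    chw = image_hwc.astype(np.float32).transpose(2, 0, 1)  # HWC -> CHW
    padded = np.pad(chw, ((0, 0), (0, pad_h - h), (0, pad_w - w)), mode="edge")
    return np.ascontiguousarray(padded[np.newaxis])  # add the batch axis


def crop_disparity(disp, height, width):
    """Drop the padded rows and columns from a disparity map."""
    return disp[..., :height, :width]


frame = np.zeros((540, 720, 3), dtype=np.uint8)  # stand-in for a rectified frame
tensor = pad_to_model_input(frame, 544, 736)
print(tensor.shape, tensor.dtype)  # (1, 3, 544, 736) float32
```

If the backend pads differently (for example, centered or zero-filled), the crop offsets have to match that choice.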

## Load FoundationStereo model

Imports the model from the cloned source and loads the checkpoint.

> **Note:** FoundationStereo's import path and model API may vary between
> versions.  Adapt the import and instantiation below to match the version
> you cloned.  Check `README.md` in the FoundationStereo repo for the
> correct usage.

In [None]:
# Add FoundationStereo source to path.
sys.path.insert(0, foundation_src)

# --- Adapt this block to match your FoundationStereo version ---
# Common import patterns:
#   from core.foundation_stereo import FoundationStereo
#   from foundation_stereo import FoundationStereo
# If the repo uses a config file, load it here too.

try:
    from core.foundation_stereo import FoundationStereo
    print("Imported FoundationStereo from core.foundation_stereo")
except ImportError:
    try:
        from foundation_stereo import FoundationStereo
        print("Imported FoundationStereo from foundation_stereo")
    except ImportError:
        raise ImportError(
            f"Cannot import FoundationStereo from {foundation_src}. "
            f"Check the repo structure and adapt the import above.")

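# Load checkpoint.
print(f"Loading checkpoint: {checkpoint_path}")
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

# Some checkpoints nest the weights under a key such as "model" or
# "state_dict" (assumed names; check your file). Unwrap them if present.
for key in ("model", "state_dict"):
    if isinstance(state_dict.get(key), dict):
        state_dict = state_dict[key]
        print(f"Using weights nested under '{key}'.")
        break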

# Handle DataParallel wrapper prefix (strip only a leading "module.").
prefix = "module."
if any(k.startswith(prefix) for k in state_dict.keys()):
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                  for k, v in state_dict.items()}
    print("Stripped 'module.' prefix from state_dict.")

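# Instantiate and load weights.
# Adapt constructor args to match your version.
model = FoundationStereo()
# strict=False skips mismatched keys silently, so report them.
result = model.load_state_dict(state_dict, strict=False)
print(f"Missing keys: {len(result.missing_keys)}, "
      f"unexpected keys: {len(result.unexpected_keys)}")
model.eval()
print("Model loaded.")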

## Export to ONNX

Traces the model at the padded resolution.  FoundationStereo typically
expects `[0, 1]` float32 input (unlike IGEV++ which uses `[0, 255]`).

> **Important:** Check the FoundationStereo source for the expected input
> range.  If the model expects `[0, 1]`, you will need to adjust the C
> backend's preprocessing or include normalization in the ONNX graph.
> The C backend (`stereo_onnx.c`) currently sends `[0, 255]` values.

In [None]:
# Ensure output directory exists.
out_dir = os.path.dirname(output_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

# Create dummy inputs.
# Adjust the value range to match what FoundationStereo expects.
left_dummy = torch.rand(1, 3, pad_h, pad_w)    # uniform in [0, 1)
right_dummy = torch.rand(1, 3, pad_h, pad_w)

print(f"Tracing with opset 16 at {pad_w}x{pad_h}...")
t0 = time.time()

with torch.no_grad():
    torch.onnx.export(
        model,
        (left_dummy, right_dummy),
        output_path,
        opset_version=16,
        input_names=["left", "right"],
        output_names=["disparity"],
        dynamic_axes=None,
        do_constant_folding=True,
    )

dt = time.time() - t0
size_mb = os.path.getsize(output_path) / (1024 * 1024)
print(f"Raw ONNX exported in {dt:.1f}s ({size_mb:.1f} MB)")

## Simplify and validate

In [None]:
import onnx
from onnxsim import simplify

print("Simplifying with onnxsim...")
t0 = time.time()

onnx_model = onnx.load(output_path)
model_sim, check = simplify(onnx_model)

if check:
    onnx.save(model_sim, output_path)
    dt = time.time() - t0
    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"Simplified in {dt:.1f}s ({size_mb:.1f} MB)")
else:
    print("onnxsim check failed, keeping unsimplified model.")

In [None]:
import onnxruntime as ort

session = ort.InferenceSession(output_path,
                               providers=["CPUExecutionProvider"])

inputs = session.get_inputs()
outputs = session.get_outputs()
print(f"Inputs:  {[(i.name, i.shape, i.type) for i in inputs]}")
print(f"Outputs: {[(o.name, o.shape, o.type) for o in outputs]}")

# Validation inference.
left = np.random.uniform(0, 1, (1, 3, pad_h, pad_w)).astype(np.float32)
right = np.random.uniform(0, 1, (1, 3, pad_h, pad_w)).astype(np.float32)

t0 = time.time()
results = session.run(None, {inputs[0].name: left, inputs[1].name: right})
dt = time.time() - t0

disp = results[-1].squeeze()
print(f"\nInference time: {dt:.2f}s")
print(f"Output shape: {disp.shape}")
print(f"Disparity range: [{disp.min():.2f}, {disp.max():.2f}]")

if disp.shape != (pad_h, pad_w):
    print(f"\nWARNING: output shape {disp.shape} != expected ({pad_h}, {pad_w})")
else:
    print("\nValidation OK.")
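
The check above only confirms the output shape. A numerical comparison against the PyTorch model on the same inputs catches export problems that leave the shape intact. The sketch below is one way to do it. The function name `disparity_parity` and the default tolerance of 0.01 px are hypothetical choices, not values taken from FoundationStereo or `ag-cam-tools`.

```python
import numpy as np


def disparity_parity(reference, candidate, atol=1e-2):
    """Compare two disparity maps and summarize the absolute error.

    `atol` is an assumed tolerance in pixels; pick one that suits your use.
    """
    reference = np.asarray(reference, dtype=np.float32).squeeze()
    candidate = np.asarray(candidate, dtype=np.float32).squeeze()
    if reference.shape != candidate.shape:
        raise ValueError(
            f"shape mismatch: {reference.shape} vs {candidate.shape}")
    abs_err = np.abs(reference - candidate)
    return {
        "max_abs_err": float(abs_err.max()),
        "mean_abs_err": float(abs_err.mean()),
        "within_tol": bool(abs_err.max() <= atol),
    }
```

In this notebook, `candidate` would be `results[-1]` and `reference` the output of `model(torch.from_numpy(left), torch.from_numpy(right))` under `torch.no_grad()`, converted with `.numpy()`. This assumes the model's `forward` accepts the two images exactly as in the export call. If it returns several tensors, compare the one that corresponds to the exported `disparity` output.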

## Done

The FoundationStereo ONNX model has been exported, simplified, and validated.

### Input range note

If FoundationStereo expects `[0, 1]` normalized input but the C backend
sends `[0, 255]`, you have two options:

1. **Prepend a division node to the ONNX graph** (recommended):
   Add `input / 255.0` as the first operation so the model accepts
   `[0, 255]` directly.  This can be done with `onnx.helper`.

2. **Modify `stereo_onnx.c`** to divide by 255 during preprocessing.

### Next steps

1. Copy the model to your target machine:
   ```bash
   scp models/foundation_stereo.onnx jetson:~/agrippa-stereocam/models/
   ```

2. Run depth preview:
   ```bash
   ag-cam-tools depth-preview \
       --rectify calibration/calibration_YYYYMMDD_HHMMSS \
       --stereo-backend onnx \
       --model-path models/foundation_stereo.onnx
   ```