# SAM2 to ONNX/ORT Conversion

This notebook converts SAM2 (Segment Anything Model 2) checkpoints to ONNX and ORT
format for browser-based inference in NimbusImage.

It exports both **encoder** and **decoder** for the **Base+** and **Large** variants.
Encoders are then converted from `.onnx` to `.ort` (ORT's pre-optimized format),
which is required for the onnxruntime-web WebGPU backend.

After running all cells, download the encoder `.ort` and decoder `.onnx` files.

## 1. Clone repos and install dependencies

In [None]:
!git clone https://github.com/vietanhdev/samexporter.git
!git clone https://github.com/facebookresearch/sam2.git

In [None]:
# Pin PyTorch 2.5.1 to use the legacy TorchScript-based ONNX exporter.
# PyTorch 2.6+ defaults to the dynamo exporter, which runs onnxscript
# constant folding that gets stuck on Resize ops (hours of pure-Python pixel work).
!pip install torch==2.5.1 torchvision==0.20.1
!pip install -e ./sam2
!pip install onnx onnxruntime timm onnxsim
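
Colab runtimes usually come with their own PyTorch preinstalled, and a module that was already imported keeps its old version until the runtime restarts. A quick guard can confirm that the pin took effect before any export runs. This is a minimal sketch: `torch_is_pinned` is a hypothetical helper, and the 2.5 target simply mirrors the pin above.

```python
import importlib.metadata


def torch_is_pinned(version: str, expected=(2, 5)) -> bool:
    """Return True if a version string like '2.5.1+cu121' matches the pinned major.minor."""
    # Drop any local build suffix such as '+cu121' before splitting.
    release = version.split("+", 1)[0]
    parts = release.split(".")
    try:
        major, minor = int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        return False
    return (major, minor) == expected


if __name__ == "__main__":
    try:
        installed = importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        print("torch is not installed")
    else:
        if torch_is_pinned(installed):
            print(f"torch {installed} OK")
        else:
            print(f"WARNING: torch {installed} is not 2.5.x; restart the runtime and reinstall")
```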

## 2. Download SAM2 checkpoints (Base+ and Large)

In [None]:
import os
os.makedirs('original_models', exist_ok=True)
os.makedirs('output_models', exist_ok=True)

!wget -q --show-progress -O original_models/sam2_hiera_base_plus.pt \
    https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt

!wget -q --show-progress -O original_models/sam2_hiera_large.pt \
    https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt

!ls -lh original_models/
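
`wget -O` creates the output file before the transfer starts, so a failed download (a wrong URL, for example) can still leave a zero-byte checkpoint behind, and `-q` hides the error message. The sketch below catches that case. `bad_checkpoints` is a hypothetical helper, and it only detects missing or empty files, not a truncated or corrupt download.

```python
from pathlib import Path

CHECKPOINTS = ["sam2_hiera_base_plus.pt", "sam2_hiera_large.pt"]


def bad_checkpoints(directory: str = "original_models") -> list:
    """Return the expected checkpoint files that are missing or zero bytes."""
    bad = []
    for name in CHECKPOINTS:
        path = Path(directory) / name
        # wget -O creates the file up front, so a failed transfer can leave it empty.
        if not path.is_file() or path.stat().st_size == 0:
            bad.append(name)
    return bad


if __name__ == "__main__":
    problems = bad_checkpoints()
    if problems:
        print("Re-download:", ", ".join(problems))
    else:
        print("Both checkpoints present.")
```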

## 3. Export Base+ model to ONNX

In [None]:
%cd samexporter

!python -m samexporter.export_sam2 \
    --checkpoint ../original_models/sam2_hiera_base_plus.pt \
    --output_encoder ../output_models/sam2_hiera_base_plus.encoder.onnx \
    --output_decoder ../output_models/sam2_hiera_base_plus.decoder.onnx \
    --model_type sam2_hiera_base_plus

%cd ..

## 4. Export Large model to ONNX

In [None]:
%cd samexporter

!python -m samexporter.export_sam2 \
    --checkpoint ../original_models/sam2_hiera_large.pt \
    --output_encoder ../output_models/sam2_hiera_large.encoder.onnx \
    --output_decoder ../output_models/sam2_hiera_large.decoder.onnx \
    --model_type sam2_hiera_large

%cd ..
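
The Base+ and Large export cells differ only in the model name. A loop can run both with exactly the flags used above. This is a sketch: `export_cmd`, `export_all`, and `MODEL_TYPES` are hypothetical names, and it assumes `sys.executable` (the kernel's interpreter) is the same Python that `!python` resolves to, which is normally the case in Colab.

```python
import subprocess
import sys

MODEL_TYPES = ["sam2_hiera_base_plus", "sam2_hiera_large"]


def export_cmd(model_type: str) -> list:
    """Build the samexporter command for one variant, to be run from inside samexporter/."""
    return [
        sys.executable, "-m", "samexporter.export_sam2",
        "--checkpoint", f"../original_models/{model_type}.pt",
        "--output_encoder", f"../output_models/{model_type}.encoder.onnx",
        "--output_decoder", f"../output_models/{model_type}.decoder.onnx",
        "--model_type", model_type,
    ]


def export_all() -> None:
    """Export every variant in MODEL_TYPES, stopping at the first failure."""
    for model_type in MODEL_TYPES:
        # cwd="samexporter" replaces the %cd samexporter / %cd .. pair.
        subprocess.run(export_cmd(model_type), cwd="samexporter", check=True)
```

Calling `export_all()` in a cell runs both exports in order. `check=True` raises `CalledProcessError` on the first failure instead of silently moving on to the next variant, and passing `cwd=` leaves the notebook's own working directory unchanged, so there is no `%cd ..` to forget.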

## 5. Convert encoders from ONNX to ORT format

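The onnxruntime-web WebGPU backend cannot load the raw `.onnx` encoder models at
runtime (graph optimization fails silently). The `.ort` format stores an
already-optimized graph, so ORT can load it directly.

Instead of the `onnxruntime.tools.convert_onnx_models_to_ort` script, the cell below
does the conversion offline by creating an `InferenceSession` with
`optimized_model_filepath` set: ORT writes the optimized graph to that path and saves
it in ORT format.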

In [None]:
import os
import onnxruntime as ort

for model_name in ['sam2_hiera_base_plus', 'sam2_hiera_large']:
    encoder_onnx = f'output_models/{model_name}.encoder.onnx'
    encoder_ort = f'output_models/{model_name}.encoder.ort'

    print(f'Converting {encoder_onnx} -> {encoder_ort} ...')

    # Use ORT_ENABLE_BASIC to apply standard graph optimizations (constant
    # folding, redundant node elimination) WITHOUT provider-specific transforms.
    # ORT_ENABLE_ALL would insert CPU-specific ops (e.g. com.microsoft.nchwc:Conv)
    # that the WebGPU backend doesn't support.
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
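    sess_options.optimized_model_filepath = encoder_ort
    # Request ORT format explicitly rather than relying on the .ort file extension.
    sess_options.add_session_config_entry('session.save_model_format', 'ORT')

    # Creating the session with optimized_model_filepath set writes the .ort file;
    # the session object itself is not needed afterwards.
    ort.InferenceSession(encoder_onnx, sess_options, providers=['CPUExecutionProvider'])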

    ort_size = os.path.getsize(encoder_ort) / 1024 / 1024
    print(f'  Saved {encoder_ort} ({ort_size:.1f} MB)')

print()
print('Done! Encoder .ort files ready.')
!ls -lh output_models/

## 6. Check output files

In [None]:
!ls -lh output_models/
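
`ls` only shows sizes. To confirm that the encoders were really written in ORT format (and not as ONNX protobuf under an `.ort` name), you can check the FlatBuffers file identifier. This is a heuristic sketch that assumes the identifier is `ORTM`, as declared in ONNX Runtime's `ort.fbs` schema; `looks_like_ort` is a hypothetical helper.

```python
from pathlib import Path

# Assumed FlatBuffers file_identifier from ONNX Runtime's ort.fbs schema.
ORT_FILE_IDENTIFIER = b"ORTM"


def looks_like_ort(path: str) -> bool:
    """Heuristic: True if bytes 4..8 of the file hold the ORT file identifier."""
    with open(path, "rb") as fh:
        header = fh.read(8)
    # A FlatBuffers file stores its 4-byte identifier right after the 4-byte root offset.
    return len(header) == 8 and header[4:8] == ORT_FILE_IDENTIFIER


if __name__ == "__main__":
    for p in sorted(Path("output_models").glob("*.encoder.ort")):
        print(p, "ORT" if looks_like_ort(str(p)) else "NOT ORT (check the conversion step)")
```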

## 7. Where to put the files

After downloading, place the files in your NimbusImage project:

```
public/onnx-models/sam/sam2_hiera_base_plus/encoder.ort      (from sam2_hiera_base_plus.encoder.ort)
public/onnx-models/sam/sam2_hiera_base_plus/decoder.onnx     (from sam2_hiera_base_plus.decoder.onnx)
public/onnx-models/sam/sam2_hiera_large/encoder.ort          (from sam2_hiera_large.encoder.ort)
public/onnx-models/sam/sam2_hiera_large/decoder.onnx         (from sam2_hiera_large.decoder.onnx)
```

Note: Encoders use `.ort` format (pre-optimized) because onnxruntime-web's WebGPU
backend cannot optimize raw `.onnx` encoder models at runtime.

The SAM1 ViT-B model files (`public/onnx-models/sam/vit_b/encoder.onnx` and
`decoder.onnx`) should already be in place from the original SAM1 setup.
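
When working from a local NimbusImage checkout, the copying can be scripted. A sketch, assuming the downloaded files keep their `output_models/` names and sit together in one directory; `place_models`, `PLACEMENT`, and `SAM1_FILES` are hypothetical names, while the destination paths are the ones listed above. It also reports which of the SAM1 ViT-B files are still missing.

```python
import shutil
from pathlib import Path

# Downloaded file name -> destination relative to the NimbusImage repo root.
PLACEMENT = {
    "sam2_hiera_base_plus.encoder.ort": "public/onnx-models/sam/sam2_hiera_base_plus/encoder.ort",
    "sam2_hiera_base_plus.decoder.onnx": "public/onnx-models/sam/sam2_hiera_base_plus/decoder.onnx",
    "sam2_hiera_large.encoder.ort": "public/onnx-models/sam/sam2_hiera_large/encoder.ort",
    "sam2_hiera_large.decoder.onnx": "public/onnx-models/sam/sam2_hiera_large/decoder.onnx",
}
SAM1_FILES = [
    "public/onnx-models/sam/vit_b/encoder.onnx",
    "public/onnx-models/sam/vit_b/decoder.onnx",
]


def place_models(download_dir: str, repo_root: str) -> list:
    """Copy the converted files into place; return any SAM1 files that are missing."""
    root = Path(repo_root)
    for src_name, dest_rel in PLACEMENT.items():
        dest = root / dest_rel
        # Create public/onnx-models/sam/<variant>/ if it does not exist yet.
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(Path(download_dir) / src_name, dest)
    return [rel for rel in SAM1_FILES if not (root / rel).is_file()]
```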

In [None]:
import os
from google.colab import files

for f in [
    'output_models/sam2_hiera_base_plus.encoder.ort',
    'output_models/sam2_hiera_base_plus.decoder.onnx',
    'output_models/sam2_hiera_large.encoder.ort',
    'output_models/sam2_hiera_large.decoder.onnx',
]:
    if os.path.exists(f):
        print(f'Downloading {f} ({os.path.getsize(f) / 1024 / 1024:.1f} MB)...')
        files.download(f)
    else:
        print(f'WARNING: {f} not found!')