<a href="https://colab.research.google.com/github/IoT-gamer/sam2-dinov3-onnx/blob/main/notebooks/edgetam_onnx_export.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## EdgeTAM ONNX Conversion
- only exporting encoder and decoder
  - memory attention is challenging to export
- inputs/outputs of the ONNX models are prepared and post-processed with NumPy/OpenCV
  - PyTorch-independent

- image size is fixed at 1024x1024
- device = CPU or GPU

## References/Acknowledgments
- [Official PyTorch implementation of "EdgeTAM: On-Device Track Anything Model"](https://github.com/facebookresearch/EdgeTAM)

## Setup

### Clone Official EdgeTAM Repo

In [None]:
# Fetch the official EdgeTAM sources into ./EdgeTAM (git's default target directory).
!git clone https://github.com/facebookresearch/EdgeTAM.git

In [None]:
# Make the cloned repository the working directory; later cells use paths relative to it.
%cd EdgeTAM

### Setup Device

In [None]:
# torch is imported here because this cell runs before the notebook's
# "Import Dependencies" cell; without it the next line raises NameError.
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

### Install Dependencies

In [None]:
# SAM2_BUILD_CUDA=0 is exported before the install below — presumably so the
# package's optional CUDA extension is not compiled; confirm against the
# EdgeTAM setup script.
%env SAM2_BUILD_CUDA=0
# Editable install of the repository in the current directory.
!pip install -e .

In [None]:
!pip install onnx
if device == "cuda":
  !pip install onnxruntime-gpu
else:
  !pip install onnxruntime

### Constants

In [None]:
# File names the exported ONNX graphs are written to and later loaded from.
EDGETAM_ENCODER_PATH = "edgetam_encoder.onnx"
EDGETAM_DECODER_PATH = "edgetam_decoder.onnx"
# Side length, in pixels, of the square model input.
EDGETAM_INPUT_SIZE = 1024 # currently fixed

### Import Dependencies

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from typing import Any
import numpy as np
import cv2
import onnx
import onnxruntime as ort
from typing import Tuple
from sam2.build_sam import build_sam2

## Wrapper Classes and Export Function

In [None]:
# Seed the PyTorch and NumPy random generators so that randomly generated
# values (e.g. the dummy export inputs below) are repeatable between runs.
torch.manual_seed(42)
np.random.seed(42)

# Define Model Wrapper Classes
class SAM2ImageEncoder(nn.Module):
    """Export wrapper around the model's image encoder.

    Runs the backbone, passes the first two FPN levels through the mask
    decoder's ``conv_s0`` / ``conv_s1`` layers, and returns plain tensors
    (instead of the backbone's dict) so the module can be traced.
    """

    def __init__(self, sam_model: Any):
        """Keep references to the sub-modules of ``sam_model`` used by ``forward``."""
        super().__init__()
        self.image_encoder = sam_model.image_encoder
        self.sam_mask_decoder = sam_model.sam_mask_decoder

    def forward(self, x: torch.Tensor):
        """Encode an image batch.

        Returns a 4-tuple: the last FPN level, the third-from-last level,
        the second-from-last level, and the last positional encoding.
        """
        encoded = self.image_encoder(x)

        # The FPN list is updated in place, exactly like the backbone output dict
        # it belongs to, before the negative indices below are resolved.
        fpn_levels = encoded["backbone_fpn"]
        fpn_levels[0] = self.sam_mask_decoder.conv_s0(fpn_levels[0])
        fpn_levels[1] = self.sam_mask_decoder.conv_s1(fpn_levels[1])

        image_embed = fpn_levels[-1]
        high_res_feats_0 = fpn_levels[-3]
        high_res_feats_1 = fpn_levels[-2]
        vision_pos_enc = encoded["vision_pos_enc"][-1]
        return image_embed, high_res_feats_0, high_res_feats_1, vision_pos_enc

class SAM2ImageDecoder(nn.Module):
    """Export wrapper combining the prompt encoder and the mask decoder.

    Point and mask prompts are embedded with tensor-only arithmetic (no
    Python branching on input values) and fed to the mask decoder's
    ``predict_masks`` together with the encoder features.
    """

    def __init__(self, sam_model: Any, multimask_output: bool):
        """Store the sub-modules of ``sam_model`` and the output mode.

        Args:
            sam_model: Model providing ``sam_mask_decoder``,
                ``sam_prompt_encoder`` and ``image_size``.
            multimask_output: If True, ``forward`` drops the first mask
                candidate and returns the remaining ones; if False it
                returns only the first candidate.
        """
        super().__init__()
        self.mask_decoder = sam_model.sam_mask_decoder
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.multimask_output = multimask_output
        self.image_size = sam_model.image_size

    def forward(self, image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input):
        """Predict masks and their IoU scores for the given prompts.

        Args:
            image_embed, high_res_feats_0, high_res_feats_1: Features
                produced by the image encoder wrapper.
            point_coords: Point prompts in input-image pixel coordinates
                (they are divided by ``image_size`` internally).
            point_labels: One label per point; -1 marks a padding point.
            mask_input: Mask prompt passed to the prompt encoder's
                ``mask_downscaling``.
            has_mask_input: 1 to use ``mask_input``, 0 to use the learned
                "no mask" embedding instead.

        Returns:
            Tuple ``(masks, iou_predictions)``.
        """
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)
        masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
            image_embeddings=image_embed, image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding, dense_prompt_embeddings=dense_embedding,
            repeat_image=False,
            high_res_features=[high_res_feats_0, high_res_feats_1]
        )
        if self.multimask_output:
            masks = masks[:, 1:, :, :]
            iou_predictions = iou_predictions[:, 1:]
        else:
            # Single-mask mode: keep only the first candidate (the same slicing
            # the upstream SAM 2 mask decoder applies in its own forward);
            # previously every candidate was returned in this mode.
            masks = masks[:, 0:1, :, :]
            iou_predictions = iou_predictions[:, 0:1]
        return masks, iou_predictions

    def _embed_points(self, point_coords, point_labels):
        """Embed point prompts; one padding point (label -1) is appended."""
        # Shift by half a pixel and append a zero padding point with label -1.
        point_coords = torch.cat([point_coords + 0.5, torch.zeros_like(point_coords[:, :1])], dim=1)
        point_labels = torch.cat([point_labels, -torch.ones_like(point_labels[:, :1])], dim=1)
        # Scale pixel coordinates by the model's input size.
        point_coords[..., 0] /= self.image_size
        point_coords[..., 1] /= self.image_size
        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
        # Zero the positional encoding of padding points, then add the learned
        # embedding that matches each label via masks rather than branches.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding += self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding += self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)
        return point_embedding

    def _embed_masks(self, input_mask, has_mask_input):
        """Blend the downscaled mask prompt with the "no mask" embedding.

        ``has_mask_input`` acts as a 0/1 switch between the two terms.
        """
        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding += (1 - has_mask_input) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

# Conversion Function with dummy inputs
def convert_all_models_to_onnx(model_cfg, checkpoint_path, input_size, device, multimask_output):
    """Build the model and export its image encoder and mask decoder to ONNX.

    The encoder graph is written to ``EDGETAM_ENCODER_PATH`` and the decoder
    graph to ``EDGETAM_DECODER_PATH`` (module-level constants). Both graphs
    are traced with randomly generated dummy inputs.

    Args:
        model_cfg: Model config passed to ``build_sam2``.
        checkpoint_path: Checkpoint passed to ``build_sam2``.
        input_size: Side length of the square dummy image used for tracing.
        device: Device the model and dummy inputs are placed on.
        multimask_output: Forwarded to ``SAM2ImageDecoder``.
    """
    # Build the model
    sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
    print("✅ Model and checkpoint loaded successfully.")

    # Convert the model to ONNX
    print("\n[0/2] Converting EdgeTAM...")

    # Convert encoder
    sam2_encoder = SAM2ImageEncoder(sam2_model).to(device).eval()
    encoder_dummy_input = torch.randn(1, 3, input_size, input_size, device=device)
    print("\n[1/2] Converting Encoder...")
    torch.onnx.export(sam2_encoder, encoder_dummy_input, EDGETAM_ENCODER_PATH, opset_version=17,
                      input_names=['image'], output_names=['image_embed', 'high_res_feats_0', 'high_res_feats_1', 'vision_pos_enc'], dynamo=False)
    print("✅ Encoder exported.")

    # Convert decoder
    sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=multimask_output).to(device).eval()
    # The encoder is run once only to obtain correctly shaped decoder inputs;
    # no gradients are needed, so skip building an autograd graph for them.
    with torch.no_grad():
        image_embed, high_res_0, high_res_1, _ = sam2_encoder(encoder_dummy_input)
    decoder_dummy_inputs = (
        image_embed,
        high_res_0,
        high_res_1,
        # One random point inside the image.
        torch.randint(0, input_size, (1, 1, 2), dtype=torch.float, device=device),
        # Random point label in {0, 1}; the upper bound of randint is exclusive,
        # so the previous bound of 1 could only ever produce 0.
        torch.randint(0, 2, (1, 1), dtype=torch.float, device=device),
        torch.randn(1, 1, 256, 256, device=device),
        torch.tensor([1], dtype=torch.float, device=device)
    )
    print("\n[2/2] Converting Decoder...")
    # Only the number of points is exported as a dynamic axis.
    torch.onnx.export(sam2_decoder, decoder_dummy_inputs, EDGETAM_DECODER_PATH, opset_version=17,
                      input_names=['image_embed', 'high_res_feats_0', 'high_res_feats_1', 'point_coords', 'point_labels', 'mask_input', 'has_mask_input'],
                      output_names=['low_res_masks', 'iou_predictions'], dynamic_axes={'point_coords': {1: 'num_points'}, 'point_labels': {1: 'num_points'}}, dynamo=False)
    print("✅ Decoder exported.")

    print("\n🎉 All models converted successfully!")


## Export models

In [None]:
# Export both ONNX graphs. The checkpoint path is relative to the current
# working directory, so the checkpoint file must already exist there.
convert_all_models_to_onnx(
    model_cfg="configs/edgetam.yaml",
    checkpoint_path="checkpoints/edgetam.pt",
    input_size=EDGETAM_INPUT_SIZE,
    device = device,
    multimask_output=True
)

## Static Image Test
- predict mask from single point

In [None]:
def mask_to_rgb(mask: np.ndarray) -> np.ndarray:
    """Return a 3-channel uint8 image that is green where ``mask == 1`` and black elsewhere."""
    green = [0, 255, 0]
    colored = np.zeros(mask.shape + (3,), dtype=np.uint8)
    # Only entries exactly equal to 1 are painted; every other value stays black.
    foreground = mask == 1
    colored[foreground] = green
    return colored

def preprocess_image(image_array: np.ndarray, input_size: int = 1024) -> np.ndarray:
    """Resize and normalize an image for EdgeTAM ONNX model inference.

    The image is stretched to ``input_size`` x ``input_size`` (no padding, so
    the aspect ratio is not preserved), normalized per channel, and returned
    in NCHW layout.

    Args:
        image_array: Image of shape (height, width, 3). The normalization
            constants are listed in the order of the ImageNet RGB statistics,
            so an RGB image is presumably expected — confirm with the caller.
        input_size: Side length of the square model input.

    Returns:
        float32 array of shape (1, 3, input_size, input_size).

    Raises:
        ValueError: If ``image_array`` is not a 3-channel HxWx3 array.
    """
    # Fail early with a clear message instead of an opaque unpacking or
    # broadcasting error further down.
    if image_array.ndim != 3 or image_array.shape[2] != 3:
        raise ValueError(
            f"expected an image of shape (H, W, 3), got {image_array.shape}"
        )

    resized = cv2.resize(image_array, (input_size, input_size), interpolation=cv2.INTER_LINEAR)

    # Normalize with ImageNet stats (expressed on the 0-255 pixel scale)
    mean = np.array([123.675, 116.28, 103.53])
    std = np.array([58.395, 57.12, 57.375])
    normalized = (resized - mean) / std

    # Transpose to CHW format and add batch dimension
    return normalized.transpose(2, 0, 1)[None, :, :, :].astype(np.float32)

def preprocess_point(
    point: np.ndarray,
    label: np.ndarray,
    orig_size: Tuple[int, int],
    resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray]:
    """Build batched float32 point prompts scaled to the resized image.

    A padding point at (0, 0) with label -1 is appended, a leading batch axis
    is added, and x / y are rescaled from the original to the resized image.

    Args:
        point: Array of (x, y) rows in original-image pixels.
        label: One label per row of ``point``.
        orig_size: (height, width) of the original image.
        resized_size: (height, width) of the model input.

    Returns:
        Tuple ``(coords, labels)`` with one more point than was passed in.
    """
    src_h, src_w = orig_size
    dst_h, dst_w = resized_size

    padding_point = np.array([[0.0, 0.0]])
    padding_label = np.array([-1])
    coords = np.concatenate([point, padding_point], axis=0)[None, :, :].astype(np.float32)
    labels = np.concatenate([label, padding_label])[None, :].astype(np.float32)

    # Rescale x by the width ratio and y by the height ratio.
    coords[..., 0] *= dst_w / src_w
    coords[..., 1] *= dst_h / src_h
    return coords, labels

def run_inference(
    encoder_session: ort.InferenceSession,
    decoder_session: ort.InferenceSession,
    image_tensor: np.ndarray,
    point_coords: np.ndarray,
    point_labels: np.ndarray,
    original_size: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray]:
    """Run the encoder and decoder sessions and upscale the best mask.

    Args:
        encoder_session: Session for the exported image encoder.
        decoder_session: Session for the exported mask decoder.
        image_tensor: Preprocessed image fed to the encoder's ``image`` input.
        point_coords: Point prompt coordinates for the decoder.
        point_labels: Point prompt labels for the decoder.
        original_size: (height, width) the selected mask is resized to.

    Returns:
        Tuple ``(mask, iou_predictions)``: the mask with the highest predicted
        IoU, resized to ``original_size`` (raw values, not thresholded), and
        the decoder's IoU predictions.
    """
    # Encoder pass; the fourth output (positional encoding) is not needed here.
    image_embed, high_res_feats_0, high_res_feats_1, _ = encoder_session.run(
        None, {'image': image_tensor}
    )

    # No mask prompt is supplied: an all-zero mask with the "has mask" flag at 0.
    decoder_feed = {
        'image_embed': image_embed,
        'high_res_feats_0': high_res_feats_0,
        'high_res_feats_1': high_res_feats_1,
        "point_coords": point_coords,
        "point_labels": point_labels,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),
    }
    low_res_masks, iou_predictions = decoder_session.run(None, decoder_feed)

    # Pick the candidate with the highest predicted IoU for the first batch item.
    top_idx = np.argmax(iou_predictions[0])
    best_low_res = low_res_masks[0, top_idx, :, :]

    # cv2.resize takes the target size as (width, height).
    target_wh = (original_size[1], original_size[0])
    upscaled = cv2.resize(best_low_res, target_wh, interpolation=cv2.INTER_LINEAR)
    return upscaled, iou_predictions

def overlay_transparent_mask(image_bgr, mask_rgb, binary_mask):
    """
    Pastes the masked region of a colour mask onto an image.

    Pixels where ``binary_mask`` is 1 are taken from ``mask_rgb``; all other
    pixels are taken from ``image_bgr``. The paste is opaque: no alpha
    blending is performed.

    Args:
        image_bgr: The background BGR image (3 channels).
        mask_rgb: The mask image (3 channels). It is combined with
            ``image_bgr`` as-is, without any colour-space conversion.
        binary_mask: The binary mask selecting the overlay region
            (1 channel, 0 or 1).

    Returns:
        The resulting image with the mask region replaced by ``mask_rgb``.
    """
    # Scale the 0/1 mask to 0/255 so it can be used as an OpenCV mask.
    alpha = (binary_mask * 255).astype(image_bgr.dtype)

    # Inverted mask selects the background pixels.
    alpha_inv = cv2.bitwise_not(alpha)

    # Keep the background outside the mask and the overlay inside it.
    masked_background = cv2.bitwise_and(image_bgr, image_bgr, mask=alpha_inv)
    masked_overlay = cv2.bitwise_and(mask_rgb, mask_rgb, mask=alpha)

    # The two parts are disjoint, so adding them combines them.
    return cv2.add(masked_background, masked_overlay)

import cv2
import numpy as np

def overlay_point(image_bgr, point_coords, color, radius, thickness=-1):
    """
    Overlays a point (as a circle) onto a copy of an image.

    Args:
        image_bgr (np.array): The background BGR image.
        point_coords (np.array): The coordinates of the point to draw, e.g., np.array([[500, 375]]).
            Only the first row is drawn.
        color (tuple): The BGR color for the point, e.g., (0, 0, 255) for red.
        radius (int): The radius of the circle representing the point.
        thickness (int): The thickness of the circle. Use -1 for a filled circle.

    Returns:
        np.array: A copy of the image with the point overlayed.
    """
    # Create a copy of the image to avoid modifying the original
    overlayed_frame = image_bgr.copy()

    # cv2.circle needs integer pixel coordinates; convert explicitly so that
    # float or NumPy-scalar coordinates are accepted as well.
    x = int(round(float(point_coords[0][0])))
    y = int(round(float(point_coords[0][1])))

    # Draw the circle
    cv2.circle(overlayed_frame, (x, y), radius, color, thickness)

    return overlayed_frame


In [None]:
# Single-point segmentation demo on a sample image from the repository.
image_path = 'notebooks/images/truck.jpg'
input_point = np.array([[500, 375]])
input_label = np.array([1])

image_bgr = cv2.imread(image_path)
# cv2.imread signals an unreadable/missing file by returning None, not by raising.
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
height, width, _ = image_rgb.shape
orig_size = (height, width)
input_tensor = preprocess_image(image_rgb, EDGETAM_INPUT_SIZE)
onnx_coord, onnx_label = preprocess_point(
    input_point, input_label, orig_size, (EDGETAM_INPUT_SIZE, EDGETAM_INPUT_SIZE)
)

edgetam_encoder_session = ort.InferenceSession(
    EDGETAM_ENCODER_PATH,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)
edgetam_decoder_session = ort.InferenceSession(
    EDGETAM_DECODER_PATH,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)
mask, _ = run_inference(
    edgetam_encoder_session, edgetam_decoder_session,
    input_tensor, onnx_coord, onnx_label, orig_size
)
# Positive values of the resized mask are treated as foreground.
binary_mask = (mask > 0).astype('uint8')
mask_rgb = mask_to_rgb(binary_mask)
overlay_image = overlay_transparent_mask(image_bgr, mask_rgb, binary_mask)
final_image_with_point = overlay_point(overlay_image, input_point, (0, 0, 255), 10, -1)

# Use the Colab display helper when it is available; otherwise open a regular
# OpenCV window. Only the import is guarded so real display errors surface.
try:
    from google.colab.patches import cv2_imshow
except ImportError:
    # cv2.imshow requires a window name and needs waitKey to actually render.
    cv2.imshow("EdgeTAM segmentation", final_image_with_point)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
else:
    cv2_imshow(final_image_with_point)
