# Exporting Gr00tPolicy to ONNX

This notebook demonstrates how to export the core neural network of a `Gr00tPolicy` to the ONNX (Open Neural Network Exchange) format using the `to_onnx` method.

**Why ONNX?**
*   **Interoperability:** Run your model in different frameworks and runtimes (ONNX Runtime, TensorRT, OpenVINO).
*   **Optimization:** ONNX runtimes often provide hardware-specific optimizations for faster inference.
*   **Deployment:** Simplifies deploying models to various platforms (servers, edge devices).

**Key Considerations:**
*   **Preprocessing/Postprocessing:** The `to_onnx` method exports *only* the underlying `torch.nn.Module` (`policy.model`). The data normalization (`policy.apply_transforms`) and denormalization (`policy.unapply_transforms`) steps are **NOT** included in the ONNX graph. You **MUST** reimplement this logic in your deployment environment.
*   **Input:** The `to_onnx` method requires an example input dictionary that has already been **normalized** using `policy.apply_transforms`.
*   **Output:** The ONNX model will output **normalized** actions.

## 1. Setup and Policy Initialization

In [1]:
import torch
import numpy as np
import os

# --- Configuration (Replace with your actual paths and settings) ---
# Example using a Hugging Face model ID
MODEL_ID = "nvidia/gr00t-1" # Or use a local path: "/path/to/your/gr00t/checkpoint"
EMBODIMENT_TAG = "franka_panda" # Specify the embodiment your model uses
OUTPUT_ONNX_PATH = "gr00t_policy_model.onnx"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Mock/Placeholder Imports (Replace with your actual classes) ---
# These classes depend on your specific gr00t setup.
# You'll need to import them correctly from your project structure.
try:
    from gr00t.model.policy import Gr00tPolicy
    from gr00t.data.dataset import ModalityConfig # Placeholder
    from gr00t.data.transform.base import ComposedModalityTransform # Placeholder
    # Add any other necessary imports for your specific ModalityConfig/Transform
    from gr00t.data.transform.state_action import StateActionToTensor, Normalizer # Example transforms
    from gr00t.data.transform import VideoTransform # Example transforms
except ImportError as e:
    print(f"Error importing gr00t components: {e}")
    print("Please ensure you have the gr00t library installed and accessible,")
    print("and replace placeholder imports with your actual class paths.")
    # Define dummy classes if imports fail, to allow notebook structure to load
    class ModalityConfig:
        def __init__(self, modality_type, delta_indices, **kwargs): pass
    class ComposedModalityTransform:
        def __init__(self, *args, **kwargs): pass
        def __call__(self, obs): return obs # Dummy implementation
        def unapply(self, act): return act # Dummy implementation
        def set_metadata(self, meta): pass
        def eval(self): pass
    class Gr00tPolicy:
        def __init__(self, *args, **kwargs): raise NotImplementedError("Dummy class, replace with actual import")

# --- Define Modality Configuration and Transforms (Crucial Step) ---
# This MUST match the configuration used during training and expected by the model.
# Replace this with your actual configuration.
modality_config = {
    # Example: Define expected modalities, horizons, etc.
    "video": ModalityConfig(modality_type="video", delta_indices=[-1, 0]), # Needs H, W, C etc.
    "state": ModalityConfig(modality_type="state", delta_indices=[-1, 0]), # Needs state keys
    "action": ModalityConfig(modality_type="action", delta_indices=list(range(16))), # Needs action keys, action_horizon
    # Add other modalities (e.g., text) if your model uses them
}

# Example: Define the sequence of transforms
# Replace with your actual transform pipeline
modality_transform = ComposedModalityTransform(
    # Example transforms - replace with yours
    # ImgTransform(...),
    # NormalizeStateTensor(...),
    # NormalizeActionTensor(...),
)

# --- Initialize the Policy ---
try:
    policy = Gr00tPolicy(
        model_path=MODEL_ID,
        embodiment_tag=EMBODIMENT_TAG,
        modality_config=modality_config,
        modality_transform=modality_transform,
        device=DEVICE,
    )
    print(f"Gr00tPolicy initialized successfully on device: {policy.device}")
    print(f"Model path resolved to: {policy.model_path}")
except Exception as e:
    print(f"Error initializing Gr00tPolicy: {e}")
    print("Please check your model path, embodiment tag, and configurations.")
    policy = None # Set policy to None if initialization fails

Error importing gr00t components: No module named 'gr00t'
Please ensure you have the gr00t library installed and accessible,
and replace placeholder imports with your actual class paths.
Error initializing Gr00tPolicy: Dummy class, replace with actual import
Please check your model path, embodiment tag, and configurations.


## 2. Prepare Example Observation Data

We need a sample observation dictionary with the correct structure and data types that the policy expects. The shapes should match your `modality_config` (especially the time dimension based on `delta_indices`).

In [None]:
# --- Create Dummy Observation Data ---
# Replace with realistic data or load a sample from your dataset.
# The shapes (especially time dimension T) must match policy requirements.

# Example: Assuming delta_indices=[-1, 0], so Time dimension T=2
T = 2
H, W, C = 224, 224, 3 # Example video dimensions
state_dim_arm = 7
state_dim_hand = 6

observations = {
    # Use keys defined in your modality_config
    'video.agentview_left': np.zeros((T, H, W, C), dtype=np.uint8),
    'video.agentview_right': np.zeros((T, H, W, C), dtype=np.uint8),
    'state.joint_positions': np.zeros((T, state_dim_arm * 2), dtype=np.float32),
    'state.gripper_positions': np.zeros((T, state_dim_hand * 2), dtype=np.float32),
    # Add other observation keys as needed by your model/config
    # 'language.instruction': ['pick up the cube'] # Example text input
}

# Helper function from policy.py (or copy it here)
def unsqueeze_dict_values(data: dict) -> dict:
    unsqueezed_data = {}
    for k, v in data.items():
        if isinstance(v, np.ndarray):
            unsqueezed_data[k] = np.expand_dims(v, axis=0)
        elif isinstance(v, torch.Tensor):
            unsqueezed_data[k] = v.unsqueeze(0)
        else:
            # Handle non-tensor data like lists of strings (e.g., language)
            if isinstance(v, list):
                 # Assuming batch size 1 for non-tensor lists
                 unsqueezed_data[k] = v
            else:
                 unsqueezed_data[k] = v # Or handle other types as needed
    return unsqueezed_data

# Add batch dimension (B=1) as policy methods expect it
observations_batched = unsqueeze_dict_values(observations)

print("Sample observation keys:", list(observations_batched.keys()))
print("Example shape (video):", observations_batched['video.agentview_left'].shape)
print("Example shape (state):", observations_batched['state.joint_positions'].shape)

## 3. Normalize Input Data

This is a **critical step**. The `to_onnx` method requires the input data to be normalized using the policy's `apply_transforms` method. The resulting tensors must also be moved to the same device as the policy model.

In [None]:
if policy is not None:
    # Apply the same preprocessing transforms used during training/inference
    normalized_input = policy.apply_transforms(observations_batched)

    # Ensure all tensors in the normalized input are on the correct device
    example_normalized_input_device = {}
    for key, value in normalized_input.items():
        if isinstance(value, torch.Tensor):
            example_normalized_input_device[key] = value.to(policy.device)
        else:
            # Handle non-tensor data if necessary (e.g., language embeddings might already be tensors)
            example_normalized_input_device[key] = value

    print("Normalized input prepared and moved to device:", policy.device)
    print("Normalized input keys:", list(example_normalized_input_device.keys()))
    # Print shape of one tensor to verify
    first_tensor_key = next(k for k, v in example_normalized_input_device.items() if isinstance(v, torch.Tensor))
    print(f"Example normalized shape ({first_tensor_key}):", example_normalized_input_device[first_tensor_key].shape)
else:
    print("Policy not initialized. Skipping normalization.")
    example_normalized_input_device = None

## 4. Define ONNX Export Parameters

We need to define:
*   `input_names`: List of names corresponding to the keys in `example_normalized_input_device`.
*   `output_names`: List of names for the output tensors produced by the ONNX model. This depends on the output structure of `policy.model.get_action`. Often, it's a single tensor (e.g., `['action_pred']`).
*   `dynamic_axes` (Optional): Allows specifying which dimensions (like batch size) can vary at runtime.

In [None]:
if example_normalized_input_device is not None:
    # Input names should match the keys of the normalized input dictionary
    input_names = list(example_normalized_input_device.keys())

    # Output names depend on what policy.model.get_action returns.
    # Check the Gr00tPolicy._get_action_from_normalized_input method.
    # It likely returns a dict like {'action_pred': tensor}, so we use ['action_pred'].
    # Adjust if your model's output structure is different.
    output_names = ["action_pred"]

    # Optional: Define dynamic axes for variable batch size
    # This allows the exported ONNX model to handle inputs with different batch sizes.
    dynamic_axes = {}
    for name in input_names:
        # Assuming the first dimension is batch for all tensor inputs
        if isinstance(example_normalized_input_device[name], torch.Tensor):
             dynamic_axes[name] = {0: 'batch_size'}
        # Add handling for other dynamic axes if needed (e.g., sequence length)

    for name in output_names:
        # Assuming the first dimension is batch for all tensor outputs
        dynamic_axes[name] = {0: 'batch_size'}
        # Add handling for other dynamic axes if needed (e.g., action horizon)

    print("Input Names:", input_names)
    print("Output Names:", output_names)
    print("Dynamic Axes:", dynamic_axes)
else:
    print("Normalized input not available. Skipping parameter definition.")

## 5. Export to ONNX

In [None]:
if policy is not None and example_normalized_input_device is not None:
    print(f"Attempting to export model to: {OUTPUT_ONNX_PATH}")
    try:
        policy.to_onnx(
            output_path=OUTPUT_ONNX_PATH,
            example_normalized_input=example_normalized_input_device,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=17, # Try other versions (e.g., 14, 16) if export fails
            verbose=True # Set to False for less output
        )
        print("\nONNX export successful!")
        print(f"Model saved to {os.path.abspath(OUTPUT_ONNX_PATH)}")
        print("\nReminder: The exported model expects NORMALIZED inputs and produces NORMALIZED outputs.")
        print("You need to implement pre-processing (normalization) and post-processing (denormalization) separately in your deployment environment.")
        onnx_export_successful = True
    except Exception as e:
        print(f"\nONNX export failed: {e}")
        # Error message already printed by policy.to_onnx
        onnx_export_successful = False
else:
    print("Policy or normalized input not ready. Skipping ONNX export.")
    onnx_export_successful = False

## 6. Verification (Optional) - Using ONNX Runtime

Let's load the exported ONNX model and run inference using `onnxruntime` to check if the output matches the PyTorch model's output (before denormalization).

In [None]:
if onnx_export_successful:
    try:
        import onnxruntime as ort
        import onnx
    except ImportError:
        print("Please install onnx and onnxruntime: pip install onnx onnxruntime")
        ort = None

    if ort:
        try:
            print(f"\nLoading ONNX model from {OUTPUT_ONNX_PATH}...")
            # Check model validity
            onnx_model = onnx.load(OUTPUT_ONNX_PATH)
            onnx.checker.check_model(onnx_model)
            print("ONNX model check passed.")

            # Create ONNX Runtime session
            # Specify providers based on your hardware (e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider'])
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if DEVICE == 'cuda' else ['CPUExecutionProvider']
            available_providers = ort.get_available_providers()
            valid_providers = [p for p in providers if p in available_providers]
            if not valid_providers:
                print(f"Warning: None of the preferred providers {providers} are available. Using default.")
                valid_providers = None # Let ORT decide
            
            print(f"Creating ONNX Runtime session with providers: {valid_providers}")
            ort_session = ort.InferenceSession(OUTPUT_ONNX_PATH, providers=valid_providers)
            print("ONNX Runtime session created.")

            # Prepare input for ONNX Runtime (needs numpy arrays)
            ort_inputs = {}
            for key, tensor in example_normalized_input_device.items():
                if isinstance(tensor, torch.Tensor):
                     ort_inputs[key] = tensor.cpu().numpy()
                else:
                     # Handle non-tensor inputs if necessary (e.g. language might need specific format)
                     ort_inputs[key] = tensor # Assuming it's already in a suitable format (like list of strings)
                     # If model expects embeddings, they should be in example_normalized_input_device

            # Run inference
            print("Running inference with ONNX Runtime...")
            ort_outputs = ort_session.run(output_names, ort_inputs)
            print("ONNX Runtime inference complete.")

            # Compare with PyTorch output (before denormalization)
            with torch.no_grad():
                 pytorch_normalized_action = policy._get_action_from_normalized_input(example_normalized_input_device)

            # Assuming the first output is the action prediction
            onnx_output_action = ort_outputs[0]
            pytorch_output_action_np = pytorch_normalized_action.cpu().numpy()

            # Check shapes
            print(f"PyTorch normalized output shape: {pytorch_output_action_np.shape}")
            print(f"ONNX normalized output shape:    {onnx_output_action.shape}")

            # Check values (using tolerance due to potential floating point differences)
            if np.allclose(pytorch_output_action_np, onnx_output_action, rtol=1e-3, atol=1e-4):
                print("\nVerification SUCCESS: ONNX Runtime output closely matches PyTorch output.")
            else:
                diff = np.abs(pytorch_output_action_np - onnx_output_action)
                print("\nVerification WARNING: ONNX Runtime output differs from PyTorch output.")
                print(f"  Max absolute difference: {np.max(diff)}")
                print(f"  Mean absolute difference: {np.mean(diff)}")
                print("  This might be due to precision differences (e.g., BF16 vs FP32) or export issues.")

        except Exception as e:
            print(f"\nError during ONNX Runtime verification: {e}")
else:
    print("\nSkipping ONNX verification because export did not succeed or ort is not available.")

## 7. Next Steps: Using the ONNX Model

1.  **Integrate into your application:** Load the `gr00t_policy_model.onnx` file using your chosen ONNX runtime (ONNX Runtime, TensorRT, etc.).
2.  **Implement Preprocessing:** Before feeding observations to the ONNX model, apply the exact same normalization steps as `policy.apply_transforms` does.
3.  **Run Inference:** Pass the normalized input tensors to the ONNX model.
4.  **Implement Postprocessing:** Take the normalized output tensors from the ONNX model and apply the inverse transformations (denormalization) equivalent to `policy.unapply_transforms` to get the final, usable actions.

In [1]:
import torch
from PIL import Image
import requests
from io import BytesIO

# Assuming eagle2_hg_model directory is accessible
# Adjust model_name path if necessary
from gr00t.model.backbone.eagle_backbone import EagleBackbone
from gr00t.model.backbone.eagle2_hg_model.inference_eagle_repo import EagleProcessor, ModelSpecificValues

# --- Configuration ---
# Adjust paths and settings as needed for your environment
MODEL_PATH = "/teamspace/studios/this_studio/Isaac-GR00T/gr00t/model/backbone/eagle2_hg_model" # Or where you have the model files
PROCESSOR_CFG = {
    "model_path": MODEL_PATH,
    "max_input_tiles": 1, # Example value, adjust based on config/needs
    "model_spec": {
        "template": "qwen2-chat", # From config.json
        "num_image_token": 64    # From config.json (adjust if scale_image_resolution != 1)
    }
}
IMAGE_URL = "https://garden.spoonflower.com/c/14186636/i/m/DYzyp2VREexAT8IEVolsqQn52vqoyzlo2P8AI4IZqg0QOjVXtA/14186636-imgonline-com-ua-resize-zlsdjrcdmm1-by-ameetlad.jpg" # Replace with an actual image URL
QUESTION = "Describe this image."
SYSTEM_PROMPT = "You are a helpful assistant." # Or use the default from the processor

# --- Initialization ---
# Ensure you are running on CPU (PyTorch defaults if no GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Instantiate processor and backbone
# Note: EagleBackbone constructor might need adjustments based on your exact needs
# (e.g., select_layer, tune_llm, tune_visual, projector_dim etc.)
processor = EagleProcessor(
    model_path=PROCESSOR_CFG["model_path"],
    max_input_tiles=PROCESSOR_CFG["max_input_tiles"],
    model_spec=ModelSpecificValues(**PROCESSOR_CFG["model_spec"]),
    use_local_eagle_hg_model=False # Set to False if using a custom path
)

backbone = EagleBackbone(
    model_name=MODEL_PATH,
    processor_cfg=PROCESSOR_CFG,
    use_local_eagle_hg_model=False # Set to False if using a custom path
).to(device)
backbone.eval() # Set to evaluation mode

# --- Prepare Input ---
# Load image (Example using requests)
try:
    response = requests.get(IMAGE_URL)
    response.raise_for_status() # Raise an exception for bad status codes
    image = Image.open(BytesIO(response.content))
except requests.exceptions.RequestException as e:
    print(f"Error downloading image: {e}")
    # Handle error appropriately, maybe exit or use a placeholder
    image = None # Or some default image

if image:
    # Prepare input using the processor
    # The processor expects a specific dictionary format
    input_params = {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": QUESTION, "image": [{"pil_image": image}]} # Pass PIL image directly if loaded
            # Or use other formats supported by load_image in inference_eagle_repo.py
            # {"role": "user", "content": QUESTION, "image": [{"url": IMAGE_URL}]}
        ],
        # Add other params if needed, e.g., "video_frame_num"
    }

    vl_input_data = processor.prepare_input(input_params)

    # Move tensors to the correct device
    vl_input = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in vl_input_data.items()}
    vl_input = backbone.prepare_input(vl_input) # Wrap in BatchFeature

    # --- Run Forward Pass ---
    with torch.no_grad():
        try:
            output = backbone(vl_input)
            print("Backbone output features shape:", output["backbone_features"].shape)
            # You can now use output["backbone_features"] and output["backbone_attention_mask"]
        except Exception as e:
            print(f"Error during forward pass: {e}")
            # Add more detailed error handling/debugging if needed
else:
    print("Cannot proceed without a valid image.")



  check_for_updates()


Using device: cpu
Eagle2ChatConfig {
  "_commit_hash": null,
  "_name_or_path": "/teamspace/studios/this_studio/Isaac-GR00T/gr00t/model/backbone/eagle2_hg_model",
  "architectures": [
    "Eagle2ChatModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_eagle_chat.Eagle2ChatConfig",
    "AutoModel": "modeling_eagle_chat.Eagle2ChatModel",
    "AutoModelForCausalLM": "modeling_eagle_chat.Eagle2ChatModel"
  },
  "downsample_ratio": 0.5,
  "dynamic_image_size": true,
  "force_image_size": 224,
  "keep_aspect_ratio": false,
  "llm_config": {
    "_name_or_path": "./pretrained/SmolLM2-1_7B-Instruct",
    "add_cross_attention": false,
    "architectures": [
      "LlamaForCausalLM"
    ],
    "attention_bias": false,
    "attention_dropout": 0.0,
    "auto_map": {
      "AutoConfig": "configuration_llama.LlamaConfig",
      "AutoModel": "modeling_llama.LlamaModel",
      "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM"
    },
    "bad_words_ids": null,
    "begin_suppress_token

ValueError: Invalid image: {'pil_image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500 at 0x73CBEC7DF250>}

In [3]:
import os
import gr00t

DEFAULT_EAGLE_MODEL_NAME = os.path.join(
    os.path.dirname(gr00t.__file__), "model", "backbone", "eagle2_hg_model"
)

In [4]:
DEFAULT_EAGLE_MODEL_NAME

'/teamspace/studios/this_studio/Isaac-GR00T/gr00t/model/backbone/eagle2_hg_model'

In [5]:
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained(DEFAULT_EAGLE_MODEL_NAME, trust_remote_code=True)
config._attn_implementation = "eager"

In [9]:
model = AutoModel.from_config(
    config,
    trust_remote_code=True,
    attn_implementation = "eager"

)
print(model.config)
model.neftune_alpha = None

ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.