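### Step 0 : Initialize the RoboFeeder Environment
This cell sets the working directory to the root of the gym4ReaL repository so that later cells can import the `gym4real` package and resolve relative file paths. If the notebook is not running from inside the `examples` folder, it prints a reminder to change the directory manually.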

```python
import os

cwd = os.getcwd()  # avoid shadowing the built-in dir()
if 'examples' in cwd:
    # Keep everything before the first 'examples' to get the repository root
    os.chdir(cwd.split('examples')[0])
else:
    print("Please set the working directory to the root of the gym4ReaL repository")

# Check that the current working directory is now the root of the gym4ReaL repository
os.getcwd()  # e.g. '/home/edge/Desktop/gym4ReaL'
```
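The string split above cuts the path at the first occurrence of the *substring* `examples`, so a parent directory such as `/home/examples_user/` would truncate the path too early. A minimal sketch, assuming the notebook lives somewhere under a directory named exactly `examples`, compares whole path components with `pathlib` instead; the helper name `repo_root_from` is hypothetical.

```python
from pathlib import Path
from typing import Optional


def repo_root_from(path: Path, marker: str = "examples") -> Optional[Path]:
    """Return the part of `path` before the first component equal to `marker`.

    Hypothetical helper: returns None when no path component matches exactly.
    """
    parts = path.parts
    if marker not in parts:
        return None
    # Rebuild the path from the components that precede the marker directory
    return Path(*parts[: parts.index(marker)])


# Usage in the notebook (only changes directory when a root was found):
# root = repo_root_from(Path.cwd())
# if root is not None:
#     os.chdir(root)
```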

### Step 1 : Import Required Modules
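This cell adds the repository root (the current working directory) to `sys.path` so that the `gym4real` package can be imported, then imports PyTorch, Stable-Baselines3's `PPO`, ONNX, ONNX Runtime, NumPy, and Gymnasium, which are used to load the trained policy and export it to ONNX.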

```python
import os
import sys

# The repository root (the *parent* of the gym4real package) must be on the import path
sys.path.append(os.getcwd())

import torch as th
from stable_baselines3 import PPO
import onnx
import onnxruntime as rt
import numpy as np
import gymnasium as gym
```
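Re-running this cell appends the same directory to `sys.path` again each time. A small sketch that avoids duplicate entries; the helper name `ensure_on_sys_path` is hypothetical.

```python
import os
import sys


def ensure_on_sys_path(path: str) -> bool:
    """Append `path` to sys.path only if it is not already there.

    Returns True if the path was added, False if it was already present.
    """
    path = os.path.abspath(path)
    if path in sys.path:
        return False
    sys.path.append(path)
    return True


# ensure_on_sys_path(os.getcwd())  # safe to call on every re-run of the cell
```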

### Step 2 : Define ONNX-Compatible Policy Wrapper

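This cell defines the `OnnxablePolicyPyTorch2` class, a thin `torch.nn.Module` wrapper around the Stable-Baselines3 policy so that it can be traced by `torch.onnx.export`. Its `forward` method calls the policy, which applies observation preprocessing itself and, for PPO, returns the tuple `(actions, values, log_prob)`. Action post-processing, such as clipping to the action-space bounds, is not part of the exported graph. As written, the wrapper passes `deterministic=False`, so the exported graph samples actions from the policy distribution; set `deterministic=True` to export the greedy (mode) action instead.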


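```python
class OnnxablePolicyPyTorch2(th.nn.Module):
    """Wrap a Stable-Baselines3 policy so it can be exported with torch.onnx.export."""

    def __init__(self, policy: th.nn.Module):
        super().__init__()
        self.policy = policy

    def forward(self, observation: th.Tensor):
        # NOTE: preprocessing is included in the policy; the only thing you may need
        # to do is transpose image observations so that they are channel-first.
        # Action post-processing (e.g. clipping to the action bounds) is NOT included.
        # deterministic=False exports the stochastic policy; use True for the greedy action.
        # For PPO, the policy returns (actions, values, log_prob).
        return self.policy(observation, deterministic=False)
```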
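To make the `deterministic` flag concrete: for a continuous (`Box`) action space, Stable-Baselines3's default PPO policy uses a diagonal Gaussian, where the deterministic action is the mean and the stochastic action is a sample around it. The stdlib-only sketch below assumes that setting (the RoboFeeder action space may differ; a discrete space would use the most likely action instead), and the function name `select_action` is hypothetical.

```python
import math
import random
from typing import List, Sequence


def select_action(
    mean: Sequence[float],
    log_std: Sequence[float],
    deterministic: bool,
    rng: random.Random,
) -> List[float]:
    """Pick an action from a diagonal Gaussian policy head (hypothetical helper)."""
    if deterministic:
        # Deterministic / greedy: return the mode of the distribution, i.e. the mean
        return list(mean)
    # Stochastic: sample a = mean + std * eps, with eps ~ N(0, 1) per dimension
    return [m + math.exp(s) * rng.gauss(0.0, 1.0) for m, s in zip(mean, log_std)]


rng = random.Random(0)
mean, log_std = [0.5, -0.2], [-1.0, -1.0]
print(select_action(mean, log_std, deterministic=True, rng=rng))   # always the mean
print(select_action(mean, log_std, deterministic=False, rng=rng))  # depends on the seed
```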

### Step 3 : Load Pretrained PPO Model

This cell loads a pretrained PPO model from the specified path and assigns it to the variable `model`.

```python
ppo_model_path = "ppo_5k.zip"  # Change this path as needed
device = "cpu"  # Change to "cuda" if a GPU is available and desired

model = PPO.load(ppo_model_path, device=device)
```
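`PPO.load` expects a zip archive written by `model.save()`. A quick sanity check before loading can give a clearer error when the path is wrong. This sketch assumes the usual Stable-Baselines3 archive layout, which contains a `data` entry and a `policy.pth` entry; the helper name `looks_like_sb3_archive` is hypothetical.

```python
import os
import zipfile


def looks_like_sb3_archive(path: str, required=("data", "policy.pth")) -> bool:
    """Heuristic check that `path` is a zip file containing the expected SB3 entries."""
    if not os.path.isfile(path) or not zipfile.is_zipfile(path):
        return False
    with zipfile.ZipFile(path) as archive:
        names = set(archive.namelist())
    return all(name in names for name in required)


# if not looks_like_sb3_archive(ppo_model_path):
#     raise FileNotFoundError(f"{ppo_model_path} is not a Stable-Baselines3 model archive")
```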

### Step 4 : Export PyTorch Policy to ONNX

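This cell wraps the PPO model's policy in `OnnxablePolicyPyTorch2` and exports it to `robofeeder_planning.onnx` by tracing it with a random dummy observation of shape `(1, *observation_space.shape)`, i.e. a batch of one. Because no dynamic axes are declared, the exported graph expects exactly that input shape. The resulting file can then be run with ONNX Runtime, without PyTorch or Stable-Baselines3.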


```python
onnx_pytorch2 = OnnxablePolicyPyTorch2(model.policy)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)  # batch of one random observation
model_output_path = "robofeeder_planning.onnx"

th.onnx.export(
    onnx_pytorch2,
    dummy_input,
    model_output_path,
    opset_version=17,  # requires a recent ONNX opset
    input_names=["input"],
    verbose=False,
)
```
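If batched inference is needed, `torch.onnx.export` accepts `output_names` and `dynamic_axes` keyword arguments that mark the batch dimension as symbolic (depending on the PyTorch version, the newer dynamo-based exporter may prefer `dynamic_shapes`). The sketch below only builds the keyword arguments, so it can be checked without PyTorch; the output names assume the `(actions, values, log_prob)` order returned by the PPO policy, and the helper name `onnx_export_kwargs` is hypothetical.

```python
from typing import Any, Dict, Sequence


def onnx_export_kwargs(
    opset_version: int = 17,
    input_names: Sequence[str] = ("input",),
    output_names: Sequence[str] = ("actions", "values", "log_prob"),
    dynamic_batch: bool = True,
) -> Dict[str, Any]:
    """Build keyword arguments for torch.onnx.export (hypothetical helper)."""
    kwargs: Dict[str, Any] = {
        "opset_version": opset_version,
        "input_names": list(input_names),
        "output_names": list(output_names),
    }
    if dynamic_batch:
        # Mark dimension 0 of every input and output as a symbolic batch dimension
        kwargs["dynamic_axes"] = {
            name: {0: "batch"} for name in [*input_names, *output_names]
        }
    return kwargs


# Usage with the objects defined in the export cell:
# th.onnx.export(onnx_pytorch2, dummy_input, model_output_path, **onnx_export_kwargs())
```

Naming the outputs also makes them easier to select by name in `InferenceSession.run`.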

### Step 5 : Load and Validate the ONNX Model

This cell loads the exported ONNX model, validates it with `onnx.checker.check_model`, and creates an ONNX Runtime inference session that uses the CUDA execution provider when it is available and falls back to the CPU provider otherwise.

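```python
# Prefer CUDA, fall back to CPU; keep only providers this onnxruntime build supports
preferred_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
providers = [p for p in preferred_providers if p in rt.get_available_providers()]

onnx_model = onnx.load(model_output_path)
onnx.checker.check_model(onnx_model)  # raises an exception if the model is invalid

ort_sess = rt.InferenceSession(model_output_path, providers=providers)
```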
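With the session created, actions can be computed by feeding a batched `float32` observation to the input named `"input"`. `InferenceSession.run(None, feed)` returns every output as a list in graph order, which for this wrapper is assumed to be `(actions, values, log_prob)`. A minimal sketch; the helper name `onnx_act` is hypothetical.

```python
from typing import Any, List


def onnx_act(session: Any, obs_batch: Any, input_name: str = "input") -> Any:
    """Return the actions computed by an ONNX Runtime session for a batch of observations.

    Assumes the first graph output is the action tensor, matching the
    (actions, values, log_prob) tuple returned by the wrapped PPO policy.
    """
    # run(None, ...) asks ONNX Runtime for every output, in graph order
    outputs: List[Any] = session.run(None, {input_name: obs_batch})
    return outputs[0]


# Usage in the notebook (obs comes from the RoboFeeder environment):
# obs_batch = np.asarray(obs, dtype=np.float32)[None]  # add the batch dimension
# action = onnx_act(ort_sess, obs_batch)[0]
```

Because the export above uses a fixed batch of one, `obs_batch` must have shape `(1, *observation_space.shape)` unless the graph was exported with a dynamic batch axis.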