How to make ONNX model of the network with CNN-MLP Multi Input policy? #1349
Comments
Hello,
When using CNN, the features extractor is shared between the actor and the critic by default, so you need to extract the features once and then split them with the MLP extractor (see the sketch below). For more details, you can read stable-baselines3/stable_baselines3/common/policies.py, lines 628 to 651 and lines 590 to 612 (commit 085bdd5).
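A minimal sketch of that wrapper, assuming SB3's `ActorCriticPolicy` attributes (`extract_features`, `mlp_extractor`, `action_net`, `value_net`) and a shared features extractor; the `OnnxablePolicy` name here is illustrative, not an official helper:

```python
import torch as th

class OnnxablePolicy(th.nn.Module):
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, observation):
        # With a shared features extractor, extract_features() returns a
        # single tensor; mlp_extractor() then produces the actor and
        # critic latents from it.
        features = self.policy.extract_features(observation)
        action_hidden, value_hidden = self.policy.mlp_extractor(features)
        return self.policy.action_net(action_hidden), self.policy.value_net(value_hidden)
```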
Btw, if you are using PyTorch 2.0 (preview) and ONNX opset 14+, this seems to work (for both deterministic and stochastic policies). Tagging @Gregwar, as the solution is quite simple, works with both MLP/CNN and with PPO/SAC (and probably the rest =)), and includes pre-processing too!

```python
import torch as th
from stable_baselines3 import PPO
class OnnxablePolicyPyTorch2(th.nn.Module):
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, observation):
        # NOTE: Preprocessing is included, the only thing you need to do
        # is transpose the images if needed so that they are channel first.
        # Use deterministic=False if you want to export the stochastic policy.
        return self.policy(observation, deterministic=True)
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnx_pytorch2 = OnnxablePolicyPyTorch2(model.policy)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
    onnx_pytorch2,
    dummy_input,
    "my_ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)
```
Load and test with onnx:

```python
import onnx
import onnxruntime as ort
import numpy as np
onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
actions, values, log_prob = ort_sess.run(None, {"input": observation})
print(actions, values, log_prob)
with th.no_grad():
    print(model.policy(th.as_tensor(observation), deterministic=True))
```
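For the CNN-MLP MultiInput case in the issue title, the same wrapper idea applies if you accept one tensor per observation key and rebuild the dict before calling the policy. A hedged sketch, with hypothetical keys "image" and "vector" and illustrative shapes (adjust both to your own observation space):

```python
import torch as th

class OnnxableMultiInputPolicy(th.nn.Module):
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, image, vector):
        # Rebuild the dict observation expected by a MultiInputPolicy;
        # torch.onnx.export traces tensor arguments, so each dict entry
        # is exposed as a separate named input.
        observation = {"image": image, "vector": vector}
        return self.policy(observation, deterministic=True)

# Example export (shapes are placeholders):
# th.onnx.export(
#     OnnxableMultiInputPolicy(model.policy),
#     (th.randn(1, 3, 64, 64), th.randn(1, 8)),
#     "my_multi_input_model.onnx",
#     opset_version=17,
#     input_names=["image", "vector"],
# )
```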
Including pre-processing in ONNX itself seems like the good way to go indeed! Another thing that would be a nice addition is a way to export the value functions as well (we use them at inference time too, e.g. to compare different policies online). So far we use a similar custom script specific to TD3, but I think it could be extended to other policies and documented; see the sketch below.
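One possible sketch of such a value-function export for TD3, assuming a flat (Box) observation space with the default flatten extractor, and relying on SB3's `ContinuousCritic.q_networks` list; the `OnnxableQValue` name is illustrative, and this is not the custom script mentioned above:

```python
import torch as th
from stable_baselines3 import TD3

class OnnxableQValue(th.nn.Module):
    def __init__(self, critic):
        super().__init__()
        # Export only the first Q-network of the twin critic
        self.q_net = critic.q_networks[0]

    def forward(self, observation, action):
        # ContinuousCritic concatenates features and actions before the
        # Q-network; with a flatten extractor the features are simply
        # the (flattened) observation.
        return self.q_net(th.cat([observation, action], dim=1))

model = TD3.load("PathToTrainedModel.zip", device="cpu")
obs_dim = model.observation_space.shape[0]
action_dim = model.action_space.shape[0]
th.onnx.export(
    OnnxableQValue(model.critic),
    (th.randn(1, obs_dim), th.randn(1, action_dim)),
    "my_td3_q_value.onnx",
    opset_version=17,
    input_names=["observation", "action"],
)
```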
This works for me too (PPO with CNN policy). It would be nice if this snippet could be added to the current documentation: https://stable-baselines3.readthedocs.io/en/master/guide/export.html
Feel free to submit a PR =)
❓ Question
I am interested in using stable-baselines3 to train an agent and then export it through ONNX.
So I wrote my export code accordingly. However, I got an error:
```
line 16, in forward
    action_hidden, value_hidden = self.extractor(observation)
ValueError: not enough values to unpack (expected 2, got 1)
```
How can I fix this?