
How to make ONNX model of the network with CNN-MLP Multi Input policy? #1349

Closed
4 tasks done
DoHyun-Chun opened this issue Feb 26, 2023 · 5 comments · Fixed by #1816
Labels
documentation (Improvements or additions to documentation) · help wanted (Help from contributors is welcomed) · question (Further information is requested)

Comments

@DoHyun-Chun

❓ Question

I am interested in using stable-baselines to train an agent, and then export it through ONNX.

So, I wrote my code as follows:

model = PPO.load('Normal_best.zip', env=env)
onnxable_model = OnnxablePolicy.OnnxablePolicy(model.policy.features_extractor, model.policy.action_net, model.policy.value_net, model)
dummy_inputs = {k: th.randn(1, *obs.shape) for k, obs in model.observation_space.items()}
th.onnx.export(
    onnxable_model,
    (dummy_inputs, {}),
    "my_multiPolicy.onnx",
    opset_version=9,
    input_names=["input"],
)

and

class OnnxablePolicy(th.nn.Module):
    def __init__(self, extractor, action_net, value_net, model):
        super(OnnxablePolicy, self).__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net
        self.model = model

    def forward(self, observation):
        # NOTE: You may have to process (normalize) observation in the correct
        #       way before using this. See `common.preprocessing.preprocess_obs`
        action_hidden, value_hidden = self.extractor(observation)
        return self.action_net(action_hidden), self.value_net(value_hidden)

However, I got an error:
line 16, in forward
action_hidden, value_hidden = self.extractor(observation)
ValueError: not enough values to unpack (expected 2, got 1)

How can I fix this?

Checklist

  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • If code there is, it is minimal and working
  • If code there is, it is formatted using the markdown code blocks for both code and stack traces.
DoHyun-Chun added the question label on Feb 26, 2023
@araffin
Member

araffin commented Feb 27, 2023

Hello,
the answer is both in the error and in the code.

ValueError: not enough values to unpack (expected 2, got 1)

When using a CNN, the features extractor is shared by default, so you need to do action_hidden = value_hidden = self.extractor(observation), since a single value (not a tuple) is returned.
And if you are using discrete actions, you will need to take the argmax (or sample), as action_net() returns the logits (it only returns the action directly in the case of continuous actions).
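
For illustration, here is a minimal sketch of a corrected wrapper along those lines (it assumes the shared features extractor feeds action_net and value_net directly, i.e. an empty net_arch, and a discrete action space; adapt it to your architecture):

import torch as th


class OnnxablePolicy(th.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, observation):
        # NOTE: you may still have to preprocess (normalize) the observation first,
        # see `common.preprocessing.preprocess_obs`
        # The shared features extractor returns a single tensor, not a tuple
        action_hidden = value_hidden = self.extractor(observation)
        # For discrete actions, action_net outputs logits,
        # so take the argmax to get the deterministic action
        action = th.argmax(self.action_net(action_hidden), dim=1)
        return action, self.value_net(value_hidden)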

For more details, you can read:

def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
    """
    Retrieve action distribution given the latent codes.

    :param latent_pi: Latent code for the actor
    :return: Action distribution
    """
    mean_actions = self.action_net(latent_pi)

    if isinstance(self.action_dist, DiagGaussianDistribution):
        return self.action_dist.proba_distribution(mean_actions, self.log_std)
    elif isinstance(self.action_dist, CategoricalDistribution):
        # Here mean_actions are the logits before the softmax
        return self.action_dist.proba_distribution(action_logits=mean_actions)
    elif isinstance(self.action_dist, MultiCategoricalDistribution):
        # Here mean_actions are the flattened logits
        return self.action_dist.proba_distribution(action_logits=mean_actions)
    elif isinstance(self.action_dist, BernoulliDistribution):
        # Here mean_actions are the logits (before rounding to get the binary actions)
        return self.action_dist.proba_distribution(action_logits=mean_actions)
    elif isinstance(self.action_dist, StateDependentNoiseDistribution):
        return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
    else:
        raise ValueError("Invalid action distribution")

and

def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Forward pass in all the networks (actor and critic)

    :param obs: Observation
    :param deterministic: Whether to sample or use deterministic actions
    :return: action, value and log probability of the action
    """
    # Preprocess the observation if needed
    features = self.extract_features(obs)
    if self.share_features_extractor:
        latent_pi, latent_vf = self.mlp_extractor(features)
    else:
        pi_features, vf_features = features
        latent_pi = self.mlp_extractor.forward_actor(pi_features)
        latent_vf = self.mlp_extractor.forward_critic(vf_features)
    # Evaluate the values for the given observations
    values = self.value_net(latent_vf)
    distribution = self._get_action_dist_from_latent(latent_pi)
    actions = distribution.get_actions(deterministic=deterministic)
    log_prob = distribution.log_prob(actions)
    actions = actions.reshape((-1,) + self.action_space.shape)
    return actions, values, log_prob

@araffin
Member

araffin commented Feb 27, 2023

Btw, if you are using PyTorch 2.0 (preview) and ONNX opset 14+, this seems to work (for deterministic and stochastic policy):

Tagging @Gregwar as the solution is quite simple and works with both MLP/CNN, and with PPO/SAC (and probably the rest =)), and includes pre-processing too!

import torch as th

from stable_baselines3 import PPO


class OnnxablePolicyPyTorch2(th.nn.Module):
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, observation):
        # NOTE: Preprocessing is included, the only thing you need to do
        # is transpose the images if needed so that they are channel first
        # use deterministic=False if you want to export the stochastic policy
        return self.policy(observation, deterministic=True)


# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")

onnx_pytorch2 = OnnxablePolicyPyTorch2(model.policy)

observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
    onnx_pytorch2,
    dummy_input,
    "my_ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

##### Load and test with onnx

import onnx
import onnxruntime as ort
import numpy as np

onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
actions, values, log_prob = ort_sess.run(None, {"input": observation})

print(actions, values, log_prob)

with th.no_grad():
    print(model.policy(th.as_tensor(observation), deterministic=True))
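
For the multi-input (Dict observation) case from the original question, an untested sketch along the same lines would feed the wrapper above a dict of dummy tensors, one per key of the observation space (the checkpoint name and input names below are just placeholders):

import torch as th

from stable_baselines3 import PPO

# Placeholder checkpoint trained with a MultiInputPolicy (CNN + MLP extractors)
model = PPO.load("Normal_best.zip", device="cpu")
onnxable_model = OnnxablePolicyPyTorch2(model.policy)

# One dummy tensor per key of the Dict observation space
dummy_inputs = {
    k: th.randn(1, *space.shape)
    for k, space in model.observation_space.spaces.items()
}
th.onnx.export(
    onnxable_model,
    (dummy_inputs, {}),
    "my_multi_policy.onnx",
    opset_version=17,
    input_names=list(dummy_inputs.keys()),
)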

araffin added the documentation label on Feb 27, 2023
@Gregwar
Contributor

Gregwar commented Feb 27, 2023

Including pre-processing in ONNX itself seems like the right way to go indeed!

Another thing that would be a nice addition is a way to export the value functions as well (we also use them at inference time, so that we can compare different policies online, for instance). So far we use a similar custom script specific to TD3, but I think it could be extended to other policies and documented.
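
For reference, a minimal, untested sketch of what such a value-function export could look like for TD3 with the default MlpPolicy (the wrapper and file names are made up for illustration):

import torch as th

from stable_baselines3 import TD3


class OnnxableTD3Critic(th.nn.Module):
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, observation, action):
        # ContinuousCritic.forward returns a tuple with one Q-value per critic network
        return self.policy.critic(observation, action)


TD3("MlpPolicy", "Pendulum-v1").save("PathToTrainedTD3")
model = TD3.load("PathToTrainedTD3.zip", device="cpu")

th.onnx.export(
    OnnxableTD3Critic(model.policy),
    (th.randn(1, *model.observation_space.shape), th.randn(1, *model.action_space.shape)),
    "my_td3_critic.onnx",
    opset_version=17,
    input_names=["observation", "action"],
)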

@GppCalcagno

Btw, if you are using PyTorch 2.0 (preview) and ONNX opset 14+, this seems to work (for deterministic and stochastic policy): […]
(snippet quoted from araffin's comment above)

This works for me too (PPO with CNN policy). It would be nice if this snippet could be added to the current documentation: https://stable-baselines3.readthedocs.io/en/master/guide/export.html

@araffin
Member

araffin commented Jul 20, 2023

This works for me too (PPO with CNN policy). It would be nice if this snippet could be added to the current documentation: https://stable-baselines3.readthedocs.io/en/master/guide/export.html

Feel free to submit a PR =)
