
How to export a trained model as a .pt (PyTorch) or ONNX model #227

Open
manaspalaparthi opened this issue Feb 27, 2024 · 2 comments
@manaspalaparthi

I have fully trained my model and want to deploy it into a Unity ML-Agents environment. To do that, I need to export the trained model in either PyTorch (.pt) or ONNX format.

The only related option I could find in the documentation is "algo.render()".


Aequatio-Space commented Feb 28, 2024

Although I do not know whether Ray can do that directly, I unwrapped a Ray checkpoint and worked out its structure.
First, load the raw checkpoint with pickle.load; you will get a dictionary whose value for the key 'worker' is a bytes object containing the model weights. Pass those bytes to pickle.loads to get the worker state dictionary, then select the key 'state' and then 'weight', which holds the raw parameters of the network. You can manually pack them into a .pt file.
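A minimal sketch of that unwrapping, assuming the key layout described above ('worker' holding a pickled bytes blob, 'state' holding the parameters; the exact keys can vary across Ray versions):

```python
import pickle

def unwrap_checkpoint(checkpoint_path):
    # The outer checkpoint file is a pickled dictionary.
    with open(checkpoint_path, "rb") as f:
        checkpoint = pickle.load(f)
    # The value under 'worker' is itself a pickled bytes blob.
    worker_state = pickle.loads(checkpoint["worker"])
    # 'state' holds the raw network parameters.
    return worker_state["state"]

# The returned dict can then be packed into a .pt file, e.g.:
#   import torch
#   torch.save(unwrap_checkpoint("checkpoint-000100"), "policy.pt")
```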


Morphlng commented Feb 28, 2024

RLlib's Policy class has an export_model method, which exports the model in the underlying learning framework's native format, with an option to save it as an ONNX model.

So the problem reduces to loading the checkpoint that MARLlib saved. I've written a personal script that loads the checkpoint together with params.json. You can reuse its load_model function to retrieve the policy and then export it:

from eval import load_model

ckpt = load_model(
    {
        "model_path": "best_model/checkpoint",
        "params_path": "best_model/params.json",
    }
)

env = marl.make_env(environment_name=ckpt.env_name, map_name=ckpt.map_name)
env_instance, env_info = env

# Change the policy name accordingly
policy = ckpt.trainer.get_policy("shared_policy")
policy.export_model("/directory/to/save")

PS: In case anybody wants to know how to use the raw model:

import numpy as np
import torch

model = policy.model
state = policy.get_initial_state()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_obs(env):
    obs = env.observation_space.sample()
    # Suppose the observation is a dict, e.g.:
    # obs = {
    #     "action_mask": [0, 0, 1, 0],
    #     "obs": [1, 1, 4, 5, 1, 4],
    # }
    for key in obs:
        obs[key] = torch.from_numpy(np.array([obs[key]])).to(DEVICE)
    return obs


dummy_input = {
    "input_dict": {"obs": get_obs(env_instance)},
    "state": [torch.from_numpy(np.array(state)).to(DEVICE)],
    "seq_lens": np.array([1])
}

output = model(**dummy_input)
