# Exporting and saving ML models

Exporting and saving are crucial steps when developing reinforcement learning (RL) models. Saving preserves the state of a model after training so you can reuse it, and exporting lets you deploy it in different environments. This guide explains how to save and export RL models, covering the formats most commonly used.

### Saving RL Models
- Saving preserves the model's architecture, its trained weights and, often, associated configuration (such as hyperparameters), so that you can reuse the model later.
- Saved models are intended to be reloaded in the same framework (here PyTorch/TorchRL), or in closely related setups, as the one used for training.

### Common Formats for Saving Models

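PyTorch (.pth or .pt): `torch.save` can store either the entire model object (pickled, so loading it requires the original class definitions to be importable) or just its state dictionary, which holds the weights and biases but not the architecture.

Safetensors: a format that stores tensors (such as the weights of a state dictionary) without using pickle, so loading a file cannot execute arbitrary code, while remaining fast to load.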


### Creating a dummy environment

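Some parameters of the RL models depend on the environment: the action spec determines what the policy outputs, and the lazy layers used below infer their input size from the observations. We therefore load MountainCar-v0 first to initialise these parameters.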
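The networks defined below use `nn.LazyLinear`, which only creates its weights once it has seen an input. A minimal sketch of that behaviour, using a zero tensor as a stand-in for a MountainCar observation (position, velocity):

```python
import torch
from torch import nn

layer = nn.LazyLinear(64)
# Until the first forward pass, the weight is a placeholder with no shape
print(isinstance(layer.weight, torch.nn.parameter.UninitializedParameter))  # True

observation = torch.zeros(1, 2)  # stand-in for one MountainCar observation
layer(observation)  # the input size (2) is inferred here
print(layer.in_features, tuple(layer.weight.shape))  # 2 (64, 2)
```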

In [1]:
!pip install torchrl==0.7.0 gymnasium==0.29 tqdm matplotlib av tensordict==0.7.2


In [11]:
import gymnasium as gym
import torch
from torch import nn
from torch.distributions import Categorical
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs import (
    Compose, DoubleToFloat, GymWrapper,
    StepCounter, TransformedEnv, set_exploration_type,
)
from torchrl.modules import ProbabilisticActor, ValueOperator

# Plain Gymnasium env, used later for the ONNX rollout
base_env = gym.make("MountainCar-v0", render_mode="rgb_array")

# TorchRL-wrapped env; categorical_action_encoding=True gives integer actions
# instead of one-hot vectors
env = GymWrapper(
    gym.make("MountainCar-v0", render_mode="rgb_array"),
    categorical_action_encoding=True,
    device="cpu",
)

env = TransformedEnv(env, Compose(
    DoubleToFloat(),
    StepCounter(),
))

print(env.action_spec)


Categorical(
    shape=torch.Size([]),
    space=CategoricalBox(n=3),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)


### Defining a simple model for RL

In [12]:
num_cells = 64

# Simple Actor-Critic Setup

# You can skip these if you want; these are the underlying neural networks.
# The actor outputs one unnormalised score (logit) per action. The policy below
# builds a Categorical distribution from these logits, which normalises them
# itself, so no Softmax layer is needed.
actor_net = nn.Sequential(
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(3),  # one logit per MountainCar action
)


value_net = nn.Sequential(
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(1),
)


# Actor Module
policy_module = ProbabilisticActor(
    module = TensorDictModule(
        actor_net, in_keys=["observation"], out_keys=["logits"]
    ),
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=Categorical,
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)

# Critic Module
value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)


In [13]:
# LazyLinear layers only create their weights after seeing an input,
# so run a dummy pass through the networks *before* saving anything
dummy_td = env.reset()
dummy_observation = dummy_td["observation"].unsqueeze(0)
# Run a dummy forward pass through the actor network
# (if your actor expects a TensorDict, call policy_module(dummy_td) instead)
_ = actor_net(dummy_observation)
# Similarly, run a dummy forward pass through the value (critic) network
_ = value_net(dummy_observation)

# Saving entire models in PyTorch (pickles the module objects)
torch.save(policy_module, "policy_module.pth")
torch.save(value_module, "value_module.pth")

# Saving model state dictionaries (weights only), one file per module
torch.save(policy_module.state_dict(), "policy_module_state_dict.pth")
torch.save(value_module.state_dict(), "value_module_state_dict.pth")
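Another common pattern is to bundle the actor and critic state dicts, plus any hyperparameters needed to rebuild them, into a single checkpoint dictionary. A minimal sketch, with plain `nn.Sequential` networks built by a hypothetical `make_net` helper standing in for `actor_net` and `value_net`, and a hypothetical file name:

```python
import torch
from torch import nn

def make_net(out_features: int, num_cells: int = 64) -> nn.Sequential:
    # Hypothetical stand-in for actor_net / value_net (2 inputs for MountainCar)
    return nn.Sequential(nn.Linear(2, num_cells), nn.Tanh(), nn.Linear(num_cells, out_features))

actor, critic = make_net(3), make_net(1)

# One file, one key per module, plus the hyperparameter needed to rebuild them
torch.save(
    {"actor": actor.state_dict(), "critic": critic.state_dict(), "num_cells": 64},
    "checkpoint.pth",
)

# Rebuild the architectures in code, then copy the saved weights into them
checkpoint = torch.load("checkpoint.pth", weights_only=True)
new_actor = make_net(3, checkpoint["num_cells"])
new_critic = make_net(1, checkpoint["num_cells"])
new_actor.load_state_dict(checkpoint["actor"])
new_critic.load_state_dict(checkpoint["critic"])
```

Because the checkpoint only contains tensors, dicts and an integer, it can be loaded with `weights_only=True`.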





### Exporting Models
- Exporting converts a model into a representation suitable for deployment in production environments or for use across different frameworks and runtimes.
- This often involves optimizations or format changes that improve inference speed and compatibility.

In [14]:
!pip install onnx onnxruntime onnxscript



In [15]:
### Exporting with torch.export

# Chain the observation transforms with the policy so that the exported
# graph takes raw observations as input
policy_transform = TensorDictSequential(
    env.transform[:-1],  # the last transform is a StepCounter, which inference does not need
    policy_module.requires_grad_(False),  # freeze parameters; export does not need gradients
)

fake_td = env.base_env.fake_tensordict()
obs = fake_td["observation"]
# DETERMINISTIC makes the actor return a deterministic action instead of sampling one
with set_exploration_type("DETERMINISTIC"):
    exported = torch.export.export(
        policy_transform.select_out_keys("action"),
        args=(),
        kwargs={"observation": obs},
        strict=False,
    )

print("Deterministic Policy")
exported.graph_module.print_readable()

### We can run inputs through the exported module as well
output = exported.module()(observation=obs)
print("Exported Module Output")
print(output)

Deterministic Policy
class GraphModule(torch.nn.Module):
    def forward(self, p_module_1_module_0_module_0_weight: "f32[64, 2]", p_module_1_module_0_module_0_bias: "f32[64]", p_module_1_module_0_module_2_weight: "f32[64, 64]", p_module_1_module_0_module_2_bias: "f32[64]", p_module_1_module_0_module_4_weight: "f32[64, 64]", p_module_1_module_0_module_4_bias: "f32[64]", p_module_1_module_0_module_6_weight: "f32[3, 64]", p_module_1_module_0_module_6_bias: "f32[3]", kwargs_observation: "f32[2]"):
         # File: /usr/local/lib/python3.11/dist-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear: "f32[64]" = torch.ops.aten.linear.default(kwargs_observation, p_module_1_module_0_module_0_weight, p_module_1_module_0_module_0_bias);  kwargs_observation = p_module_1_module_0_module_0_weight = p_module_1_module_0_module_0_bias = None
        


In [None]:
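import torch.onnx
import onnxruntime

### Using ONNX to export
# torch.onnx.dynamo_export traces the module with TorchDynamo (requires onnxscript);
# newer PyTorch releases recommend torch.onnx.export(..., dynamo=True) instead
with set_exploration_type("DETERMINISTIC"):
    obs = fake_td["observation"]
    onnx_program = torch.onnx.dynamo_export(policy_transform, observation=obs)

### Save ONNX model
onnx_file_path = "policy.onnx"
onnx_program.save(onnx_file_path)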




### Loading a Model

In [8]:
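# Since PyTorch 2.6, torch.load defaults to weights_only=True, which cannot unpickle
# whole module objects; only pass weights_only=False for files you trust
loaded_policy = torch.load("policy_module.pth", weights_only=False)
loaded_value = torch.load("value_module.pth", weights_only=False)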

In [9]:
### Loading via ONNX
ort_session = onnxruntime.InferenceSession(
    onnx_file_path, providers=["CPUExecutionProvider"]
)

# Map the graph's single input name to the observation as a NumPy array
onnxruntime_input = {ort_session.get_inputs()[0].name: obs.numpy()}
onnx_outputs = ort_session.run(None, onnxruntime_input)

In [10]:
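# Running a rollout with ONNX Runtime on the plain Gymnasium env
import numpy as np
from torchrl._utils import timeit  # private TorchRL helper, used here for simple timing

input_name = ort_session.get_inputs()[0].name

def onnx_policy(observation: np.ndarray) -> int:
    # The exported graph takes the raw observation; its first output is the action
    onnxruntime_outputs = ort_session.run(None, {input_name: observation})
    return int(onnxruntime_outputs[0])

num_steps = 1000
with timeit("ONNX rollout"):
    reset_output = base_env.reset()
    print(reset_output)  # (observation, info)
    observation, _ = reset_output
    for _ in range(num_steps):
        action = onnx_policy(observation)
        observation, _reward, terminated, truncated, _info = base_env.step(action)
        if terminated or truncated:
            observation, _ = base_env.reset()

# By default env.rollout stops as soon as the episode ends (MountainCar-v0 truncates
# episodes at 200 steps), so it may time fewer than num_steps steps
with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
    env.rollout(num_steps, policy_module)

print(timeit.print())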

(array([-0.5069718,  0.       ], dtype=float32), {})


2025-05-06 05:17:40,364 [torchrl][INFO] ONNX rollout took 210.5 msec (total = 0.21054577827453613 sec)
2025-05-06 05:17:40,367 [torchrl][INFO] TorchRL version took 672.4 msec (total = 0.6723911762237549 sec)


ONNX rollout took 210.5 msec (total = 0.21054577827453613 sec)
TorchRL version took 672.4 msec (total = 0.6723911762237549 sec)
