**This is an example of how to trace a model with TorchScript (JIT) and export it to ONNX**

In [None]:
from rl_games.torch_runner import Runner
import ray
import yaml
import torch
import matplotlib.pyplot as plt
import gym
from IPython import display
import numpy as np
import onnx
import onnxruntime as ort
%matplotlib inline

In [None]:
# Start Ray with an explicit object-store budget (1000 * 1024 * 1024 bytes).
object_store_bytes = 1000 * 1024 * 1024
ray.init(object_store_memory=object_store_bytes)

In [None]:
# Path to the YAML training config that is opened and parsed in the next cell.
# NOTE(review): the path is relative, so it presumably expects the notebook to
# be run from the repository root — confirm.
config_name = 'rl_games/configs/ppo_cartpole.yaml'

In [None]:
# Parse the YAML config; the file handle is only needed for the load itself.
with open(config_name, 'r') as config_file:
    config = yaml.safe_load(config_file)

# Override the experiment name for this run.
# NOTE(review): the checkpoint path used later ('runs/cartpole_onnx/...')
# presumably derives from this name — confirm against the runner.
config['params']['config']['full_experiment_name'] = 'cartpole_onnx'

# Build the runner from the config and launch training.
runner = Runner()
runner.load(config)
runner.run({'train': True})

In [None]:
import rl_games.algos_torch.flatten as flatten

# Create a player from the runner and load the checkpoint saved by training.
agent = runner.create_player()
agent.restore('runs/cartpole_onnx/nn/cartpole_vel_info.pth')

# Example inputs used only to drive tracing: one all-zero observation on the
# agent's device plus whatever RNN state the player currently holds.
# NOTE(review): agent.states is presumably None for a non-recurrent policy,
# which is why allow_non_tensor=True is passed below — confirm.
inputs = {
    'obs': torch.zeros((1,) + agent.obs_shape).to(agent.device),
    'rnn_states': agent.states,
}

with torch.no_grad():
    # The adapter wraps the network so it can be called with a flat sequence
    # of inputs, which is what torch.jit.trace is given here.
    adapter = flatten.TracingAdapter(agent.model.a2c_network, inputs, allow_non_tensor=True)
    traced = torch.jit.trace(adapter, adapter.flattened_inputs, check_trace=False)
    # Run the traced module once and print its outputs as a sanity check.
    flattened_outputs = traced(*adapter.flattened_inputs)
    print(flattened_outputs)

# Pass the example inputs as ONE tuple. Unpacking them with `*` only works
# when there is exactly one flattened input; with more (e.g. RNN states) the
# second tensor would land in the positional slot for the output file.
torch.onnx.export(
    traced,
    tuple(adapter.flattened_inputs),
    "cartpole.onnx",
    verbose=True,
    input_names=['obs'],
    output_names=['logits', 'value'],
)

onnx_model = onnx.load("cartpole.onnx")

# Check that the model is well formed
onnx.checker.check_model(onnx_model)

In [None]:
# Load the exported model with ONNX Runtime and run it once on an all-zero
# float32 input of shape (1, 4) as a smoke test.
ort_model = ort.InferenceSession("cartpole.onnx")

zero_obs = np.zeros((1, 4), dtype=np.float32)
outputs = ort_model.run(None, {"obs": zero_obs})
print(outputs)

In [None]:
# Play one episode in the agent's environment, choosing actions with the
# ONNX Runtime session instead of the PyTorch model.
env = agent.env
obs = env.reset()
total_reward, num_steps = 0, 0
is_done = False

while not is_done:
    # Add a leading batch axis and cast to float32 before feeding the session.
    batched_obs = np.expand_dims(obs, axis=0).astype(np.float32)
    outputs = ort_model.run(None, {"obs": batched_obs})
    # The first output is named 'logits' at export time; act greedily on it.
    action = np.argmax(outputs[0])
    obs, reward, is_done, info = env.step(action)
    total_reward += reward
    num_steps += 1
    screen = env.render(mode='rgb_array')
    # To display the episode inline, uncomment:
    # plt.imshow(screen)
    # display.display(plt.gcf())
    # display.clear_output(wait=True)

print(total_reward, num_steps)