In [None]:
import functools
import jax
import os
import time
# jax.config.update('jax_platform_name', "cpu")

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

# Location of the GLVND EGL vendor config that names NVIDIA's EGL library.
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'

# Only write the config when it is missing; an existing file is left untouched.
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as icd_file:
    icd_file.write(
        '{\n'
        '    "file_format_version" : "1.0.0",\n'
        '    "ICD" : {\n'
        '        "library_path" : "libEGL_nvidia.so.0"\n'
        '    }\n'
        '}\n'
    )

# Configure MuJoCo to use the EGL rendering backend (requires GPU).
# These are IPython `%env` magics, so this cell only runs inside a notebook kernel.
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl
# NOTE(review): presumably disables XLA's up-front GPU memory preallocation so
# the renderer and JAX can share the device — confirm against the JAX docs.
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo

import distutils.util
import os
import subprocess


import jax.numpy as jnp
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
import mediapy as media
import matplotlib.pyplot as plt

from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
import functools
from matplotlib import pyplot as plt

from datetime import datetime

# Compact numpy printing: 3 decimals, fixed-point notation, 100-character rows.
np.set_printoptions(linewidth=100, precision=3, suppress=True)

from ambersim.envs.exo import Exo

# Register the Exo environment class with Brax under `env_name`, then build one
# instance through the registry.
env_name = "exo"
envs.register_environment(env_name, Exo)
env = envs.get_environment(env_name)

In [None]:
# Network factory: the policy network gets four hidden layers of 64 units;
# every other network option is left at make_ppo_networks' defaults.
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(64,) * 4,
)

# PPO trainer with all hyperparameters bound up front, so the later call only
# has to supply the environments and the progress callback.
train_fn = functools.partial(
    ppo.train,
    # Step budget and evaluation schedule.
    num_timesteps=1000,
    num_evals=10,
    num_resets_per_eval=10,
    episode_length=1000,
    # Experience collection.
    num_envs=2048,
    unroll_length=20,
    action_repeat=1,
    normalize_observations=True,
    reward_scaling=1,
    # Optimisation.
    batch_size=1024,
    num_minibatches=8,
    num_updates_per_batch=4,
    learning_rate=1e-4,
    entropy_cost=1e-2,
    discounting=0.99,
    gae_lambda=0.95,
    # Model and reproducibility.
    network_factory=make_networks_factory,
    seed=0,
)

def progress(num_steps, metrics):
  """Training callback: record one evaluation point and redraw the reward curve.

  Args:
    num_steps: Environment-step count reported with this evaluation; used as
      the x coordinate of the new point.
    metrics: Mapping that must provide 'eval/episode_reward' and
      'eval/episode_reward_std'.

  Side effects:
    Appends to the notebook-level lists `times`, `x_data`, `y_data` and
    `ydataerr`, and draws a new matplotlib figure. Axis limits come from the
    notebook-level `min_y`, `max_y` and `train_fn`.
  """
  # Bookkeeping first: wall-clock timestamp, then the new data point.
  times.append(datetime.now())
  x_data.append(num_steps)
  reward_mean = metrics['eval/episode_reward']
  y_data.append(reward_mean)
  reward_std = metrics['eval/episode_reward_std']
  ydataerr.append(reward_std)

  # Leave 25% headroom past the configured step budget on the x axis.
  x_upper = train_fn.keywords['num_timesteps'] * 1.25
  plt.xlim([0, x_upper])
  plt.ylim([min_y, max_y])

  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  # The title shows the most recent mean reward.
  plt.title(f'y={y_data[-1]:.3f}')

  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.show()

# Buffers filled by `progress`. `times` is seeded with the launch timestamp so
# that the first callback can be used to measure the initial (JIT) phase.
x_data, y_data, ydataerr = [], [], []
times = [datetime.now()]
min_y, max_y = 0, 1

# Build fresh environment instances for training and for evaluation.
# NOTE(review): the original comment cited tracers from a domain randomization
# function; none is visible in this notebook — confirm whether still relevant.
env = envs.get_environment(env_name)
eval_env = envs.get_environment(env_name)

# Run PPO; the third return value is not used by this notebook.
make_inference_fn, params, _ = train_fn(
    environment=env,
    eval_env=eval_env,
    progress_fn=progress,
)

# times[0] is the launch, times[1] the first progress callback, times[-1] the last.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

In [None]:
#@title Save Model
model_path = 'policies/mjx_brax_policy'

# Make sure the target directory exists before saving: on a fresh checkout
# there is no `policies/` folder yet and the save would otherwise fail.
model_dir = os.path.dirname(model_path)
if model_dir:  # a bare filename has no directory component to create
  os.makedirs(model_dir, exist_ok=True)

# Persist the trained parameters returned by the training cell.
model.save_params(model_path, params)

In [None]:
# Reload the parameters from `model_path`, replacing the in-memory `params`
# with what was written to disk.
params = model.load_params(model_path)

# Build the policy from the loaded parameters and JIT-compile it for the
# rollout cell below.
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

In [None]:

# Roll the trained policy out for 100 steps in a fresh evaluation environment.
eval_env = envs.get_environment(env_name)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

rollout = []   # environment state observed at the start of each step
actions = []   # control applied at each step (same indexing as `rollout`)
rng = jax.random.PRNGKey(0)
# Use the jitted functions defined above (the names previously referenced here,
# `jit_env_reset` / `jit_env_step`, were never defined).
state = jit_reset(rng)

for _step in range(100):
    start = time.time()
    rollout.append(state)
    # Split so that every step samples its action with a fresh key.
    act_rng, rng = jax.random.split(rng)
    ctrl, _extras = jit_inference_fn(state.obs, act_rng)
    actions.append(ctrl)
    # Step with the action that was just computed (previously an undefined `act`).
    state = jit_step(state, ctrl)
    end = time.time()
    # NOTE(review): JAX may dispatch asynchronously, so this can under-report
    # the real compute time unless the result is blocked on — confirm if the
    # numbers matter.
    print(f"step time: {end - start}")


In [None]:
import mujoco
from mujoco import mjx
def get_image(renderer, env, state, camera):
    """Render one frame of `state` with the camera tracking the model in x/y.

    Args:
        renderer: Object exposing `update_scene(data, camera=...)` and `render()`.
        env: Environment whose `model` attribute is the MuJoCo model to render.
        state: MJX data to copy into a host-side `mujoco.MjData`.
        camera: Camera object; its `lookat[0]` and `lookat[1]` are overwritten
            in place.

    Returns:
        Whatever `renderer.render()` returns (presumably an image array —
        confirm against the renderer in use).
    """
    # Copy the MJX data into a freshly allocated MjData.
    mj_data = mujoco.MjData(env.model)
    mjx.device_get_into(mj_data, state)
    mujoco.mj_forward(env.model, mj_data)

    # Point the camera at the first two generalized coordinates
    # (assumed to be the base x/y position — TODO confirm for this model).
    for axis in (0, 1):
        camera.lookat[axis] = mj_data.qpos[axis]

    # Refresh the scene from the MjData, then rasterize it.
    renderer.update_scene(mj_data, camera=camera)
    return renderer.render()

# Prepare the environment's renderer/camera before they are read below.
# NOTE(review): assumes getRender() is what populates env.renderer and
# env.camera — confirm in the Exo environment.
env.getRender()

# One frame per recorded rollout state, rendered from its pipeline state.
images = [
    get_image(env.renderer, env, step_state.pipeline_state, env.camera)
    for step_state in rollout
]

# Preview inline at one frame per environment step (fps = 1 / dt).
media.show_video(images, fps=1.0 / env.dt)

In [None]:
# Output path for the rendered rollout, relative to the current working directory.
output_file = 'exo_ppo.mp4'
# Save the video
# NOTE(review): written at a fixed 30 fps, whereas the inline preview uses
# 1.0 / env.dt — the saved clip may play at a different speed; confirm intent.
media.write_video(output_file, images,fps=30) 