In [None]:
from IPython.display import clear_output, HTML
import functools
import matplotlib.pyplot as plt
from tqdm import tqdm

import jax
import jax.numpy as jnp

import brax.training.agents.ppo.train as ppo
import brax.training.agents.es.train as es
from brax.io import model
from brax.io import html

from rl_racer.envs.v1.hover import HoverV1

In [None]:
# Steps per episode: passed as `episode_length` to both trainers and used as
# the number of iterations of the rollout loop at the end of the notebook.
EPISODE_LENGTH = 200

In [None]:
# PPO settings gathered in one mapping so they can be inspected at a glance
# before being bound to the trainer.
_PPO_KWARGS = dict(
    # Run length / evaluation.
    num_timesteps=1_000_000,
    num_evals=10,
    episode_length=EPISODE_LENGTH,
    seed=42,
    # Environment handling.
    num_envs=4096,
    action_repeat=1,
    normalize_observations=True,
    reward_scaling=1,
    # Optimisation.
    unroll_length=5,
    batch_size=2048,
    num_minibatches=32,
    num_updates_per_batch=4,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    discounting=0.97,
)

# Pre-configured PPO trainer; the caller still supplies `environment` and
# `progress_fn` when invoking it.
train_ppo = functools.partial(ppo.train, **_PPO_KWARGS)

In [None]:
# Settings for the `es.train` trainer, kept in one mapping before binding.
_ES_KWARGS = dict(
    num_timesteps=1_000_000,
    num_evals=10,
    episode_length=EPISODE_LENGTH,
    normalize_observations=True,
    seed=1,
)

# Pre-configured ES trainer; the caller still supplies `environment` and
# `progress_fn` when invoking it.
train_es = functools.partial(es.train, **_ES_KWARGS)

In [None]:
# Trainer used by the training cell below; assign `train_es` here instead to
# train with the other pre-configured trainer.
train_fn = train_ppo

In [None]:
# Instantiate the environment and JIT-compile its reset/step functions; the
# compiled versions are what the rollout loop at the end of the notebook calls.
env = HoverV1()
reset = jax.jit(env.reset)
step = jax.jit(env.step)

In [None]:
# History consumed by `progress`; re-created each time this cell runs so a
# re-run starts with an empty curve.
xdata, ydata = [], []


def progress(num_steps, metrics):
  """Append the latest evaluation reward and redraw the learning curve.

  Args:
    num_steps: value appended to the x-axis history (labelled as the number of
      environment steps on the plot).
    metrics: mapping read at the key 'eval/episode_reward'.
  """
  xdata.append(num_steps)
  ydata.append(metrics['eval/episode_reward'])

  # wait=True defers clearing until the new figure is ready, so the plot is
  # replaced in place rather than flickering or stacking up.
  clear_output(wait=True)
  plt.plot(xdata, ydata)
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.show()


# Run training with the selected trainer; `progress` is passed as the
# `progress_fn` callback.
make_inference_fn, params, metrics = train_fn(
    environment=env,
    progress_fn=progress,
)

In [None]:
# Save the trained parameters under the relative path 'trained_params'.
model.save_params('trained_params', params)

In [None]:
# Build the policy function from the trained parameters and JIT-compile it;
# the rollout loop calls it as `inference_fn(obs, rng_key)`.
inference_fn = jax.jit(make_inference_fn(params))

In [None]:
# Roll the trained policy out for EPISODE_LENGTH steps, collecting the
# pipeline state of every step, then render the trajectory to HTML.
rollout = []
rng = jax.random.PRNGKey(seed=42)
state = reset(rng=rng)
for _ in tqdm(range(EPISODE_LENGTH)):
  # Record the state *before* stepping so the first frame is the reset state.
  rollout.append(state.pipeline_state)
  # Split off a fresh key for this action; the other half is carried forward
  # so no key is reused across steps.
  act_rng, rng = jax.random.split(rng)
  act, _ = inference_fn(state.obs, act_rng)
  state = step(state, act)

doc = html.render(env.sys.replace(dt=env.dt), rollout)
# Write with an explicit encoding: the default for text-mode `open` depends on
# the platform locale, which can fail or produce a mis-encoded page elsewhere.
with open('trained.html', 'w', encoding='utf-8') as f:
  f.write(doc)
# Keep this as the last expression so the notebook displays the rendered page.
HTML(doc)