# Training in Brax

Once an environment is created in Brax, we can quickly train an agent in it using Brax's built-in training algorithms. Let's try it out!

In [15]:
# Enable autoreload so edits to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a GPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.

import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output,display, IFrame

import brax


import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac

# When the COLAB_TPU_ADDR environment variable is set, run JAX's Colab TPU
# setup before any device is queried.
if os.environ.get('COLAB_TPU_ADDR') is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

# Report the kind of the first device JAX sees.
print("loaded on", jax.devices()[0].device_kind, "device.")

loaded on NVIDIA RTX A6000 device.


First let's pick an environment and a backend to train an agent in. 

Recall from the [Brax Basics](https://github.com/google/brax/blob/main/notebooks/basics.ipynb) colab that the backend specifies which physics engine to use, each with different trade-offs between physical realism and training throughput/speed. The engines generally decrease in physical realism but increase in speed in the following order: `generalized`, `positional`, then `spring`.


In [10]:
#@title Load Env { run: "auto" }

env_name = 'franka'
backend = 'generalized'  # @param ['generalized', 'positional', 'spring']

# Look up the environment for the chosen backend, then reset it once
# (jit-compiled, fixed seed 0) to obtain a state to preview.
env = envs.get_environment(env_name=env_name, backend=backend)
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

# Write a single-frame HTML rendering of that initial state to the working directory.
html_file = f"{env_name}_preview.html"
with open(html_file, 'w') as f:
  f.write(html.render(env.sys, [state.pipeline_state], height=1600))




# Training

Brax provides out of the box the following training algorithms:

* [Proximal policy optimization](https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py)
* [Soft actor-critic](https://github.com/google/brax/blob/main/brax/training/agents/sac/train.py)
* [Evolutionary strategy](https://github.com/google/brax/blob/main/brax/training/agents/es/train.py)
* [Analytic policy gradients](https://github.com/google/brax/blob/main/brax/training/agents/apg/train.py)
* [Augmented random search](https://github.com/google/brax/blob/main/brax/training/agents/ars/train.py)

Trainers take as input an environment function and some hyperparameters, and return an inference function to operate the environment.

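# Rollout

No policy is trained in the cell below. Instead, we step the environment selected above (`env_name` with the chosen `backend`) with a fixed action and render the resulting trajectory to HTML.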

In [15]:
#@title Visualizing a trajectory under a constant zero action

# Number of environment steps to record in the rollout.
num_steps = 1000

# create an env with auto-reset
env = envs.create(env_name=env_name, backend=backend)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

# No trained policy is used here: the environment is driven with a fixed
# all-zero action. The action does not depend on the state, so it is built
# once instead of on every loop iteration.
act = jp.zeros(env.action_size, dtype=float)

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(num_steps):
  # Record the state before stepping, so the rollout starts at the reset state.
  rollout.append(state.pipeline_state)

  # Step environment
  state = jit_env_step(state, act)


# save html to file
html_file = f"{env_name}.html"
with open(html_file, 'w') as f:
  # Render with the system's opt.timestep replaced by env.dt.
  f.write(html.render(env.sys.tree_replace({'opt.timestep': env.dt}), rollout))



🙌 See you soon!

In [43]:
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a GPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.

import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output,display, IFrame

import brax


import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac

# When the COLAB_TPU_ADDR environment variable is set, run JAX's Colab TPU
# setup before any device is queried.
if os.environ.get('COLAB_TPU_ADDR') is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

# Report the kind of the first device JAX sees.
print("loaded on", jax.devices()[0].device_kind, "device.")

#@title Visualizing a trajectory under a constant 0.1 action

# Number of environment steps to record in the rollout.
num_steps = 1000

# create an env with auto-reset
env = envs.create(env_name=env_name, backend=backend)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

# No trained policy is used here: every action component is held at 0.1.
# The action does not depend on the state, so it is built once instead of
# on every loop iteration.
act = jp.ones(env.action_size, dtype=float)*0.1

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(num_steps):
  # Record the state before stepping, so the rollout starts at the reset state.
  rollout.append(state.pipeline_state)

  # Step environment
  state = jit_env_step(state, act)


# save html to file
html_file = f"{env_name}.html"
with open(html_file, 'w') as f:
  # Render with the system's opt.timestep replaced by env.dt.
  f.write(html.render(env.sys.tree_replace({'opt.timestep': env.dt}), rollout))

loaded on NVIDIA RTX A6000 device.


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f328d3af610>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


In [17]:
import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output,display, IFrame

import brax


import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac


class BraxHandler():
    """Wrapper around a Brax environment that steps from caller-supplied states.

    On construction the environment is created, its ``reset`` and
    ``set_state_and_step`` methods are jit-compiled, and the environment is
    reset once. Pipeline states collected for rendering are kept in
    ``self.rollout``, which starts with the initial state's pipeline state.
    """

    def __init__(self, env_name, backend,rng_seed=0):
        """Create the environment and record its initial state.

        Args:
            env_name: Environment name passed to ``envs.create``.
            backend: Backend name passed to ``envs.create``.
            rng_seed: Seed of the PRNG key used for the initial reset.
        """
        self.env_name = env_name
        self.backend = backend

        self.env = envs.create(env_name=env_name, backend=backend)
        rng = jax.random.PRNGKey(seed=rng_seed)
        self.jit_env_set_and_step = jax.jit(self.env.set_state_and_step)
        self.jit_env_reset = jax.jit(self.env.reset)
        # Not read anywhere inside this class; kept so that code accessing
        # the attribute from outside keeps working.
        self.jit_env_rng = jax.random.PRNGKey(seed=1)
        self.init_state = self.jit_env_reset(rng=rng)
        # The rollout always starts with the frame of the reset state.
        self.rollout = [self.init_state.pipeline_state]

    def perform_step(self,input_state, act, render=False):
        """Advance the environment by one step starting from ``input_state``.

        Args:
            input_state: State to step from; forwarded to the jitted
                ``set_state_and_step``.
            act: Action to apply for this step.
            render: When true, append the resulting state's pipeline state
                to the rollout.

        Returns:
            The state returned by the environment after the step.
        """
        state = self.jit_env_set_and_step(input_state, act)
        if render:
            # Record the state produced by this step. When steps are chained
            # from ``init_state`` the state stepped from is already in the
            # rollout, so recording the input again would duplicate the first
            # frame and never capture the final one.
            self.rollout.append(state.pipeline_state)
        return state

    def get_rollout(self):
        """Return the list of pipeline states recorded so far (not a copy)."""
        return self.rollout

    def save_rollout(self, filename):
        """Render the recorded rollout to HTML and write it to ``filename``.

        Args:
            filename: Path of the HTML file to create or overwrite.
        """
        # Render with the system's opt.timestep replaced by env.dt.
        # NOTE(review): presumably so playback matches the environment's
        # control step — confirm.
        render_sys = self.env.sys.tree_replace({'opt.timestep': self.env.dt})
        with open(filename, 'w') as f:
            f.write(html.render(render_sys, self.rollout))
# Test the class
env_name = 'walker2d_mpc'
backend = 'generalized'
num_test_steps = 1000  # number of steps to simulate and render

brax_handler = BraxHandler(env_name, backend)
state = brax_handler.init_state

# The same all-ones action is applied on every step, so the array is built
# once instead of on every loop iteration.
act = jp.ones(brax_handler.env.action_size, dtype=float)
for _ in range(num_test_steps):
    # Feed each resulting state back in as the next step's starting state.
    state = brax_handler.perform_step(state, act,render=True)
brax_handler.save_rollout(f"{env_name}_test.html")