<a href="https://colab.research.google.com/github/Denys88/rl_games/blob/master/notebooks/brax_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install git+https://github.com/Denys88/rl_games

Collecting git+https://github.com/Denys88/rl_games
  Cloning https://github.com/Denys88/rl_games to /tmp/pip-req-build-mr7os74g
  Running command git clone --filter=blob:none --quiet https://github.com/Denys88/rl_games /tmp/pip-req-build-mr7os74g
  Resolved https://github.com/Denys88/rl_games to commit 42c076edaf071e7f5a5f9154e1f3c1302c038aba
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting gym<0.24.0,>=0.23.0 (from gym[classic-control]<0.24.0,>=0.23.0->rl_games==1.6.1)
  Using cached gym-0.23.1.tar.gz (626 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting opencv-python<5.0.0,>=4.5.5 (from rl_games==1.6.1)
  Using cached opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting

In [2]:
#@title Brax training example
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs using a GPU runtime. From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.

import copy
from datetime import datetime
import functools
import os

from IPython.display import HTML, clear_output

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax.io import html
from brax.io import model

In [3]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 4090 Laptop GPU (UUID: GPU-53f9f624-49f2-daaf-3d97-93c5ff2684f6)


In [4]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic used in the next cell.
%load_ext tensorboard

In [5]:
# Launch an inline TensorBoard watching runs/, the same root directory the checkpoint paths below use.
%tensorboard --logdir 'runs/'

In [6]:
## Brax "ant" PPO config, in the nested {'params': {...}} layout that is passed to rl_games' Runner.load().
ant_config = dict(params=dict(
    algo=dict(name='a2c_continuous'),
    config=dict(
        bound_loss_type='regularisation',
        bounds_loss_coef=0.0,
        clip_value=True,
        critic_coef=4,
        e_clip=0.2,
        entropy_coef=0.0,
        env_config=dict(env_name='ant', seed=5),
        env_name='brax',
        gamma=0.99,
        grad_norm=1.0,
        horizon_length=8,
        kl_threshold=0.008,
        # A string, as YAML-style configs carry it; presumably converted to float by rl_games (TODO confirm).
        learning_rate='3e-4',
        lr_schedule='adaptive',
        max_epochs=5000,
        mini_epochs=4,
        minibatch_size=32768,
        name='ant-brax',
        normalize_advantage=True,
        normalize_input=True,
        normalize_value=True,
        num_actors=4096,
        player=dict(render=True),
        ppo=True,
        reward_shaper=dict(scale_value=0.1),
        schedule_type='standard',
        score_to_win=20000,
        tau=0.95,
        truncate_grads=True,
        use_smooth_clamp=True,
        value_bootstrap=True,
    ),
    model=dict(name='continuous_a2c_logstd'),
    network=dict(
        mlp=dict(
            activation='elu',
            initializer=dict(name='default'),
            units=[256, 128, 64],
        ),
        name='actor_critic',
        separate=False,
        space=dict(continuous=dict(
            fixed_sigma=True,
            mu_activation='None',
            mu_init=dict(name='default'),
            sigma_activation='None',
            sigma_init=dict(name='const_initializer', val=0),
        )),
    ),
    seed=5,
))

In [7]:
## Brax "humanoid" PPO config. Same layout and PPO settings as ant_config except for the env and
## experiment names, a larger MLP (512-256-128), horizon_length 16 and mini_epochs 5.
## The selected config also builds the player that restores the checkpoint for rendering, so its
## network and normalization settings must stay identical to the ones used for training.
humanoid_config = dict(params=dict(
    algo=dict(name='a2c_continuous'),
    config=dict(
        bound_loss_type='regularisation',
        bounds_loss_coef=0.0,
        clip_value=True,
        critic_coef=4,
        e_clip=0.2,
        entropy_coef=0.0,
        env_config=dict(env_name='humanoid', seed=5),
        env_name='brax',
        gamma=0.99,
        grad_norm=1.0,
        horizon_length=16,
        kl_threshold=0.008,
        # A string, as YAML-style configs carry it; presumably converted to float by rl_games (TODO confirm).
        learning_rate='3e-4',
        lr_schedule='adaptive',
        max_epochs=5000,
        mini_epochs=5,
        minibatch_size=32768,
        name='humanoid-brax',
        normalize_advantage=True,
        normalize_input=True,
        normalize_value=True,
        num_actors=4096,
        player=dict(render=True),
        ppo=True,
        reward_shaper=dict(scale_value=0.1),
        schedule_type='standard',
        score_to_win=20000,
        tau=0.95,
        truncate_grads=True,
        use_smooth_clamp=True,
        value_bootstrap=True,
    ),
    model=dict(name='continuous_a2c_logstd'),
    network=dict(
        mlp=dict(
            activation='elu',
            initializer=dict(name='default'),
            units=[512, 256, 128],
        ),
        name='actor_critic',
        separate=False,
        space=dict(continuous=dict(
            fixed_sigma=True,
            mu_activation='None',
            mu_init=dict(name='default'),
            sigma_activation='None',
            sigma_init=dict(name='const_initializer', val=0),
        )),
    ),
    seed=5,
))

In [8]:
import yaml
from rl_games.torch_runner import Runner

env_name = 'ant'  # @param ['ant', 'humanoid']
configs = {
    'ant' : ant_config,
    'humanoid' : humanoid_config
}

# Checkpoint restored after training for each env: runs/<env key>/nn/<config 'name'>.pth.
# The env key doubles as the experiment directory because full_experiment_name is set to
# env_name below; deriving the path keeps it in sync with each config's 'name'.
networks = {
    key: os.path.join('runs', key, 'nn', cfg['params']['config']['name'] + '.pth')
    for key, cfg in configs.items()
}

# Training length for this notebook run; overrides the configs' own max_epochs.
MAX_EPOCHS = 1000

# Deep-copy so the overrides below never leak back into ant_config / humanoid_config:
# the config cells keep showing what is actually stored, and re-running this cell after
# switching env_name always starts from the untouched config.
config = copy.deepcopy(configs[env_name])
network_path = networks[env_name]
config['params']['config']['full_experiment_name'] = env_name
config['params']['config']['max_epochs'] = MAX_EPOCHS

In [9]:
# Train PPO on the selected config with rl_games' Runner.
# NOTE(review): with the rl_games commit installed above, this cell raised
# "TypeError: BraxEnv.__init__() got multiple values for argument 'env_name'".
# The 'env_name' key inside env_config appears to collide with a BraxEnv parameter of the
# same name — TODO reconcile env_config with the installed BraxEnv signature.
runner = Runner()
runner.load(config)
runner.run(dict(train=True))

self.seed = 5
Started to train
Exact experiment name requested from command line: ant


TypeError: BraxEnv.__init__() got multiple values for argument 'env_name'

In [None]:
from rl_games.envs.brax import BraxEnv

from IPython.display import HTML, IFrame, display, clear_output
import os

In [None]:
# Build an inference player from the loaded config and restore the trained checkpoint.
agent = runner.create_player()
agent.restore(network_path)

# Standalone single-actor Brax env for the rendered rollout (the QP snapshots squeeze axis 0,
# which only works with exactly one actor).
# NOTE(review): this splats env_config (which contains 'env_name') into BraxEnv, the pattern
# that raised the TypeError in the training cell — likely fails the same way until resolved.
num_actors = 1
env_config = runner.params['config']['env_config']
env = BraxEnv('', num_actors, **env_config)

In [None]:
# Fresh episode bookkeeping: recorded state snapshots, running return and step counter.
qps, total_reward, num_steps = [], 0, 0
obs = env.reset()

class QP:
    """Snapshot of a Brax-style ``qp`` state holding only ``pos`` and ``rot``.

    Both arrays have their leading axis (size 1 for the single-actor rollout env)
    squeezed away, so each snapshot stores the arrays for one step without that axis.
    """

    def __init__(self, qp):
        def unbatch(array):
            # jax.numpy.squeeze raises if axis 0 is not of size 1, i.e. this requires one actor.
            return jax.numpy.squeeze(array, axis=0)

        self.pos = unbatch(qp.pos)
        self.rot = unbatch(qp.rot)

# Roll out one episode with the restored policy, recording a QP snapshot before every step.
# Safety cap so this cell cannot spin forever if the env never reports done.
MAX_STEPS = 10_000

is_done = False
while not is_done and num_steps < MAX_STEPS:
    qps.append(QP(env.env._state.qp))
    # Evaluate/render the trained policy with its mean action instead of a random sample.
    act = agent.get_action(obs, is_deterministic=True)
    # The player returns an unbatched action; restore the single-actor batch axis for the env.
    obs, reward, is_done, info = env.step(act.unsqueeze(0))
    total_reward += reward.item()
    num_steps += 1

if not is_done:
    print(f'Warning: stopped after MAX_STEPS={MAX_STEPS} steps without the episode finishing')
print('Total Reward: ', total_reward)
print('Num steps: ', num_steps)

In [None]:
def visualize(sys, qps):
    """Render a recorded trajectory as an IPython HTML object.

    Args:
        sys: Brax system object, passed unchanged to ``brax.io.html.render``.
        qps: sequence of per-step states (in this notebook, the ``QP`` snapshots).

    Returns:
        ``IPython.display.HTML`` wrapping the markup produced by ``html.render``.
    """
    page = html.render(sys, qps)
    return HTML(page)

In [None]:
# Render the recorded rollout using the system object of the wrapped Brax env.
brax_system = env.env._env.sys
display(visualize(brax_system, qps))